Spaces:
Runtime error
Runtime error
Upload 60 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- LICENSE +201 -0
- configs/vlog_read_script_sample.yaml +39 -0
- configs/vlog_write_script.yaml +3 -0
- configs/with_mask_ref_sample.yaml +36 -0
- configs/with_mask_sample.yaml +33 -0
- datasets/__pycache__/video_transforms.cpython-310.pyc +0 -0
- datasets/video_transforms.py +382 -0
- diffusion/__init__.py +47 -0
- diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
- diffusion/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
- diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc +0 -0
- diffusion/__pycache__/respace.cpython-310.pyc +0 -0
- diffusion/diffusion_utils.py +88 -0
- diffusion/gaussian_diffusion.py +931 -0
- diffusion/respace.py +130 -0
- diffusion/timestep_sampler.py +150 -0
- input/i2v/A_big_drop_of_water_falls_on_a_rose_petal.png +3 -0
- input/i2v/A_fish_swims_past_an_oriental_woman.png +3 -0
- input/i2v/Cinematic_photograph_View_of_piloting_aaero.png +3 -0
- input/i2v/Planet_hits_earth.png +3 -0
- input/i2v/Underwater_environment_cosmetic_bottles.png +3 -0
- models/__init__.py +33 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/attention.cpython-310.pyc +0 -0
- models/__pycache__/clip.cpython-310.pyc +0 -0
- models/__pycache__/resnet.cpython-310.pyc +0 -0
- models/__pycache__/unet.cpython-310.pyc +0 -0
- models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
- models/attention.py +966 -0
- models/clip.py +123 -0
- models/resnet.py +212 -0
- models/unet.py +699 -0
- models/unet_blocks.py +650 -0
- models/utils.py +215 -0
- requirements.txt +25 -0
- results/mask_no_ref/Planet_hits_earth..mp4 +0 -0
- results/mask_ref/Planet_hits_earth..mp4 +0 -0
- results/vlog/teddy_travel/ref_img/teddy.jpg +0 -0
- results/vlog/teddy_travel/script/protagonist_place_reference.txt +0 -0
- results/vlog/teddy_travel/script/protagonists_places.txt +22 -0
- results/vlog/teddy_travel/script/time_scripts.txt +94 -0
- results/vlog/teddy_travel/script/video_prompts.txt +0 -0
- results/vlog/teddy_travel/script/zh_video_prompts.txt +95 -0
- results/vlog/teddy_travel/story.txt +1 -0
- results/vlog/teddy_travel_/story.txt +1 -0
- sample_scripts/vlog_read_script_sample.py +303 -0
- sample_scripts/vlog_write_script.py +91 -0
- sample_scripts/with_mask_ref_sample.py +275 -0
- sample_scripts/with_mask_sample.py +249 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
input/i2v/A_big_drop_of_water_falls_on_a_rose_petal.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
input/i2v/A_fish_swims_past_an_oriental_woman.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
input/i2v/Cinematic_photograph_View_of_piloting_aaero.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
input/i2v/Planet_hits_earth.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
input/i2v/Underwater_environment_cosmetic_bottles.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
configs/vlog_read_script_sample.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# path:
|
2 |
+
ckpt: "pretrained/ShowMaker.pt"
|
3 |
+
pretrained_model_path: "pretrained/stable-diffusion-v1-4/"
|
4 |
+
image_encoder_path: "pretrained/OpenCLIP-ViT-H-14"
|
5 |
+
save_path: "results/vlog/teddy_travel/video"
|
6 |
+
|
7 |
+
# script path
|
8 |
+
reference_image_path: ["results/vlog/teddy_travel/ref_img/teddy.jpg"]
|
9 |
+
script_file_path: "results/vlog/teddy_travel/script/video_prompts.txt"
|
10 |
+
zh_script_file_path: "results/vlog/teddy_travel/script/zh_video_prompts.txt"
|
11 |
+
protagonist_file_path: "results/vlog/teddy_travel/script/protagonists_places.txt"
|
12 |
+
reference_file_path: "results/vlog/teddy_travel/script/protagonist_place_reference.txt"
|
13 |
+
time_file_path: "results/vlog/teddy_travel/script/time_scripts.txt"
|
14 |
+
video_transition: False
|
15 |
+
|
16 |
+
# model config:
|
17 |
+
model: UNet
|
18 |
+
num_frames: 16
|
19 |
+
image_size: [320, 512]
|
20 |
+
negative_prompt: "white background"
|
21 |
+
|
22 |
+
# sample config:
|
23 |
+
ref_cfg_scale: 0.3
|
24 |
+
seed: 3407
|
25 |
+
guidance_scale: 7.5
|
26 |
+
cfg_scale: 8.0
|
27 |
+
sample_method: 'ddim'
|
28 |
+
num_sampling_steps: 100
|
29 |
+
researve_frame: 3
|
30 |
+
mask_type: "first3"
|
31 |
+
use_mask: True
|
32 |
+
use_fp16: True
|
33 |
+
enable_xformers_memory_efficient_attention: True
|
34 |
+
do_classifier_free_guidance: True
|
35 |
+
fps: 8
|
36 |
+
sample_num:
|
37 |
+
|
38 |
+
# model speedup
|
39 |
+
use_compile: False
|
configs/vlog_write_script.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# script path
|
2 |
+
story_path: "./results/vlog/teddy_travel_/story.txt"
|
3 |
+
only_one_protagonist: False
|
configs/with_mask_ref_sample.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# path config:
|
2 |
+
ckpt: "pretrained/ShowMaker.pt"
|
3 |
+
pretrained_model_path: "pretrained/stable-diffusion-v1-4/"
|
4 |
+
image_encoder_path: "pretrained/OpenCLIP-ViT-H-14"
|
5 |
+
input_path: 'input/i2v/Planet_hits_earth.png'
|
6 |
+
ref_path: 'input/i2v/Planet_hits_earth.png'
|
7 |
+
save_path: "results/mask_ref/"
|
8 |
+
|
9 |
+
# model config:
|
10 |
+
model: UNet
|
11 |
+
num_frames: 16
|
12 |
+
# image_size: [320, 512]
|
13 |
+
image_size: [240, 560]
|
14 |
+
|
15 |
+
# model speedup
|
16 |
+
use_fp16: True
|
17 |
+
enable_xformers_memory_efficient_attention: True
|
18 |
+
|
19 |
+
# sample config:
|
20 |
+
seed: 3407
|
21 |
+
cfg_scale: 8.0
|
22 |
+
ref_cfg_scale: 0.5
|
23 |
+
sample_method: 'ddim'
|
24 |
+
num_sampling_steps: 100
|
25 |
+
text_prompt: [
|
26 |
+
# "Cinematic photograph. View of piloting aaero.",
|
27 |
+
# "A fish swims past an oriental woman.",
|
28 |
+
# "A big drop of water falls on a rose petal.",
|
29 |
+
# "Underwater environment cosmetic bottles.".
|
30 |
+
"Planet hits earth.",
|
31 |
+
]
|
32 |
+
additional_prompt: ""
|
33 |
+
negative_prompt: ""
|
34 |
+
do_classifier_free_guidance: True
|
35 |
+
mask_type: "first1"
|
36 |
+
use_mask: True
|
configs/with_mask_sample.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# path config:
|
2 |
+
ckpt: "pretrained/ShowMaker.pt"
|
3 |
+
pretrained_model_path: "pretrained/OpenCLIP-ViT-H-14"
|
4 |
+
input_path: 'input/i2v/Planet_hits_earth.png'
|
5 |
+
save_path: "results/mask_no_ref/"
|
6 |
+
|
7 |
+
# model config:
|
8 |
+
model: UNet
|
9 |
+
num_frames: 16
|
10 |
+
# image_size: [320, 512]
|
11 |
+
image_size: [240, 560]
|
12 |
+
|
13 |
+
# model speedup
|
14 |
+
use_fp16: True
|
15 |
+
enable_xformers_memory_efficient_attention: True
|
16 |
+
|
17 |
+
# sample config:
|
18 |
+
seed: 3407
|
19 |
+
cfg_scale: 8.0
|
20 |
+
sample_method: 'ddim'
|
21 |
+
num_sampling_steps: 100
|
22 |
+
text_prompt: [
|
23 |
+
# "Cinematic photograph. View of piloting aaero.",
|
24 |
+
# "A fish swims past an oriental woman.",
|
25 |
+
# "A big drop of water falls on a rose petal.",
|
26 |
+
# "Underwater environment cosmetic bottles.".
|
27 |
+
"Planet hits earth.",
|
28 |
+
]
|
29 |
+
additional_prompt: ""
|
30 |
+
negative_prompt: ""
|
31 |
+
do_classifier_free_guidance: True
|
32 |
+
mask_type: "first1"
|
33 |
+
use_mask: True
|
datasets/__pycache__/video_transforms.cpython-310.pyc
ADDED
Binary file (12.4 kB). View file
|
|
datasets/video_transforms.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import numbers
|
4 |
+
from torchvision.transforms import RandomCrop, RandomResizedCrop
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
def _is_tensor_video_clip(clip):
|
8 |
+
if not torch.is_tensor(clip):
|
9 |
+
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
10 |
+
|
11 |
+
if not clip.ndimension() == 4:
|
12 |
+
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
13 |
+
|
14 |
+
return True
|
15 |
+
|
16 |
+
|
17 |
+
def center_crop_arr(pil_image, image_size):
|
18 |
+
"""
|
19 |
+
Center cropping implementation from ADM.
|
20 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
21 |
+
"""
|
22 |
+
while min(*pil_image.size) >= 2 * image_size:
|
23 |
+
pil_image = pil_image.resize(
|
24 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
25 |
+
)
|
26 |
+
|
27 |
+
scale = image_size / min(*pil_image.size)
|
28 |
+
pil_image = pil_image.resize(
|
29 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
30 |
+
)
|
31 |
+
|
32 |
+
arr = np.array(pil_image)
|
33 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
34 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
35 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
36 |
+
|
37 |
+
|
38 |
+
def crop(clip, i, j, h, w):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
42 |
+
"""
|
43 |
+
if len(clip.size()) != 4:
|
44 |
+
raise ValueError("clip should be a 4D tensor")
|
45 |
+
return clip[..., i : i + h, j : j + w]
|
46 |
+
|
47 |
+
|
48 |
+
def resize(clip, target_size, interpolation_mode):
|
49 |
+
if len(target_size) != 2:
|
50 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
51 |
+
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
|
52 |
+
|
53 |
+
def resize_scale(clip, target_size, interpolation_mode):
|
54 |
+
if len(target_size) != 2:
|
55 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
56 |
+
H, W = clip.size(-2), clip.size(-1)
|
57 |
+
scale_ = target_size[0] / min(H, W)
|
58 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
59 |
+
|
60 |
+
def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
|
61 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
|
62 |
+
|
63 |
+
def resize_scale_with_height(clip, target_size, interpolation_mode):
|
64 |
+
H, W = clip.size(-2), clip.size(-1)
|
65 |
+
scale_ = target_size / H
|
66 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
67 |
+
|
68 |
+
def resize_scale_with_weight(clip, target_size, interpolation_mode):
|
69 |
+
H, W = clip.size(-2), clip.size(-1)
|
70 |
+
scale_ = target_size / W
|
71 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
72 |
+
|
73 |
+
|
74 |
+
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
75 |
+
"""
|
76 |
+
Do spatial cropping and resizing to the video clip
|
77 |
+
Args:
|
78 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
79 |
+
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
80 |
+
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
81 |
+
h (int): Height of the cropped region.
|
82 |
+
w (int): Width of the cropped region.
|
83 |
+
size (tuple(int, int)): height and width of resized clip
|
84 |
+
Returns:
|
85 |
+
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
|
86 |
+
"""
|
87 |
+
if not _is_tensor_video_clip(clip):
|
88 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
89 |
+
clip = crop(clip, i, j, h, w)
|
90 |
+
clip = resize(clip, size, interpolation_mode)
|
91 |
+
return clip
|
92 |
+
|
93 |
+
|
94 |
+
def center_crop(clip, crop_size):
|
95 |
+
if not _is_tensor_video_clip(clip):
|
96 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
97 |
+
h, w = clip.size(-2), clip.size(-1)
|
98 |
+
# print(clip.shape)
|
99 |
+
th, tw = crop_size
|
100 |
+
if h < th or w < tw:
|
101 |
+
# print(h, w)
|
102 |
+
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
|
103 |
+
|
104 |
+
i = int(round((h - th) / 2.0))
|
105 |
+
j = int(round((w - tw) / 2.0))
|
106 |
+
return crop(clip, i, j, th, tw)
|
107 |
+
|
108 |
+
|
109 |
+
def center_crop_using_short_edge(clip):
|
110 |
+
if not _is_tensor_video_clip(clip):
|
111 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
112 |
+
h, w = clip.size(-2), clip.size(-1)
|
113 |
+
if h < w:
|
114 |
+
th, tw = h, h
|
115 |
+
i = 0
|
116 |
+
j = int(round((w - tw) / 2.0))
|
117 |
+
else:
|
118 |
+
th, tw = w, w
|
119 |
+
i = int(round((h - th) / 2.0))
|
120 |
+
j = 0
|
121 |
+
return crop(clip, i, j, th, tw)
|
122 |
+
|
123 |
+
|
124 |
+
def random_shift_crop(clip):
|
125 |
+
'''
|
126 |
+
Slide along the long edge, with the short edge as crop size
|
127 |
+
'''
|
128 |
+
if not _is_tensor_video_clip(clip):
|
129 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
130 |
+
h, w = clip.size(-2), clip.size(-1)
|
131 |
+
|
132 |
+
if h <= w:
|
133 |
+
long_edge = w
|
134 |
+
short_edge = h
|
135 |
+
else:
|
136 |
+
long_edge = h
|
137 |
+
short_edge =w
|
138 |
+
|
139 |
+
th, tw = short_edge, short_edge
|
140 |
+
|
141 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
142 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
143 |
+
return crop(clip, i, j, th, tw)
|
144 |
+
|
145 |
+
|
146 |
+
def to_tensor(clip):
|
147 |
+
"""
|
148 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
149 |
+
permute the dimensions of clip tensor
|
150 |
+
Args:
|
151 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
152 |
+
Return:
|
153 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
154 |
+
"""
|
155 |
+
_is_tensor_video_clip(clip)
|
156 |
+
if not clip.dtype == torch.uint8:
|
157 |
+
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
158 |
+
# return clip.float().permute(3, 0, 1, 2) / 255.0
|
159 |
+
return clip.float() / 255.0
|
160 |
+
|
161 |
+
|
162 |
+
def normalize(clip, mean, std, inplace=False):
|
163 |
+
"""
|
164 |
+
Args:
|
165 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
166 |
+
mean (tuple): pixel RGB mean. Size is (3)
|
167 |
+
std (tuple): pixel standard deviation. Size is (3)
|
168 |
+
Returns:
|
169 |
+
normalized clip (torch.tensor): Size is (T, C, H, W)
|
170 |
+
"""
|
171 |
+
if not _is_tensor_video_clip(clip):
|
172 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
173 |
+
if not inplace:
|
174 |
+
clip = clip.clone()
|
175 |
+
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
176 |
+
# print(mean)
|
177 |
+
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
178 |
+
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
179 |
+
return clip
|
180 |
+
|
181 |
+
|
182 |
+
def hflip(clip):
|
183 |
+
"""
|
184 |
+
Args:
|
185 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
186 |
+
Returns:
|
187 |
+
flipped clip (torch.tensor): Size is (T, C, H, W)
|
188 |
+
"""
|
189 |
+
if not _is_tensor_video_clip(clip):
|
190 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
191 |
+
return clip.flip(-1)
|
192 |
+
|
193 |
+
|
194 |
+
class RandomCropVideo:
|
195 |
+
def __init__(self, size):
|
196 |
+
if isinstance(size, numbers.Number):
|
197 |
+
self.size = (int(size), int(size))
|
198 |
+
else:
|
199 |
+
self.size = size
|
200 |
+
|
201 |
+
def __call__(self, clip):
|
202 |
+
"""
|
203 |
+
Args:
|
204 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
205 |
+
Returns:
|
206 |
+
torch.tensor: randomly cropped video clip.
|
207 |
+
size is (T, C, OH, OW)
|
208 |
+
"""
|
209 |
+
i, j, h, w = self.get_params(clip)
|
210 |
+
return crop(clip, i, j, h, w)
|
211 |
+
|
212 |
+
def get_params(self, clip):
|
213 |
+
h, w = clip.shape[-2:]
|
214 |
+
th, tw = self.size
|
215 |
+
|
216 |
+
if h < th or w < tw:
|
217 |
+
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
218 |
+
|
219 |
+
if w == tw and h == th:
|
220 |
+
return 0, 0, h, w
|
221 |
+
|
222 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
223 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
224 |
+
|
225 |
+
return i, j, th, tw
|
226 |
+
|
227 |
+
def __repr__(self) -> str:
|
228 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
229 |
+
|
230 |
+
class CenterCropResizeVideo:
|
231 |
+
'''
|
232 |
+
First use the short side for cropping length,
|
233 |
+
center crop video, then resize to the specified size
|
234 |
+
'''
|
235 |
+
def __init__(
|
236 |
+
self,
|
237 |
+
size,
|
238 |
+
interpolation_mode="bilinear",
|
239 |
+
):
|
240 |
+
if isinstance(size, tuple):
|
241 |
+
if len(size) != 2:
|
242 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
243 |
+
self.size = size
|
244 |
+
else:
|
245 |
+
self.size = (size, size)
|
246 |
+
|
247 |
+
self.interpolation_mode = interpolation_mode
|
248 |
+
|
249 |
+
|
250 |
+
def __call__(self, clip):
|
251 |
+
"""
|
252 |
+
Args:
|
253 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
254 |
+
Returns:
|
255 |
+
torch.tensor: scale resized / center cropped video clip.
|
256 |
+
size is (T, C, crop_size, crop_size)
|
257 |
+
"""
|
258 |
+
# print(clip.shape)
|
259 |
+
clip_center_crop = center_crop_using_short_edge(clip)
|
260 |
+
# print(clip_center_crop.shape) 320 512
|
261 |
+
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
262 |
+
return clip_center_crop_resize
|
263 |
+
|
264 |
+
def __repr__(self) -> str:
|
265 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
266 |
+
|
267 |
+
|
268 |
+
class CenterCropVideo:
|
269 |
+
def __init__(
|
270 |
+
self,
|
271 |
+
size,
|
272 |
+
interpolation_mode="bilinear",
|
273 |
+
):
|
274 |
+
if isinstance(size, tuple):
|
275 |
+
if len(size) != 2:
|
276 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
277 |
+
self.size = size
|
278 |
+
else:
|
279 |
+
self.size = (size, size)
|
280 |
+
|
281 |
+
self.interpolation_mode = interpolation_mode
|
282 |
+
|
283 |
+
|
284 |
+
def __call__(self, clip):
|
285 |
+
"""
|
286 |
+
Args:
|
287 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
288 |
+
Returns:
|
289 |
+
torch.tensor: center cropped video clip.
|
290 |
+
size is (T, C, crop_size, crop_size)
|
291 |
+
"""
|
292 |
+
clip_center_crop = center_crop(clip, self.size)
|
293 |
+
return clip_center_crop
|
294 |
+
|
295 |
+
def __repr__(self) -> str:
|
296 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
297 |
+
|
298 |
+
|
299 |
+
class NormalizeVideo:
|
300 |
+
"""
|
301 |
+
Normalize the video clip by mean subtraction and division by standard deviation
|
302 |
+
Args:
|
303 |
+
mean (3-tuple): pixel RGB mean
|
304 |
+
std (3-tuple): pixel RGB standard deviation
|
305 |
+
inplace (boolean): whether do in-place normalization
|
306 |
+
"""
|
307 |
+
|
308 |
+
def __init__(self, mean, std, inplace=False):
|
309 |
+
self.mean = mean
|
310 |
+
self.std = std
|
311 |
+
self.inplace = inplace
|
312 |
+
|
313 |
+
def __call__(self, clip):
|
314 |
+
"""
|
315 |
+
Args:
|
316 |
+
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
|
317 |
+
"""
|
318 |
+
return normalize(clip, self.mean, self.std, self.inplace)
|
319 |
+
|
320 |
+
def __repr__(self) -> str:
|
321 |
+
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
322 |
+
|
323 |
+
|
324 |
+
class ToTensorVideo:
|
325 |
+
"""
|
326 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
327 |
+
permute the dimensions of clip tensor
|
328 |
+
"""
|
329 |
+
|
330 |
+
def __init__(self):
|
331 |
+
pass
|
332 |
+
|
333 |
+
def __call__(self, clip):
|
334 |
+
"""
|
335 |
+
Args:
|
336 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
337 |
+
Return:
|
338 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
339 |
+
"""
|
340 |
+
return to_tensor(clip)
|
341 |
+
|
342 |
+
def __repr__(self) -> str:
|
343 |
+
return self.__class__.__name__
|
344 |
+
|
345 |
+
|
346 |
+
class ResizeVideo():
|
347 |
+
'''
|
348 |
+
First use the short side for cropping length,
|
349 |
+
center crop video, then resize to the specified size
|
350 |
+
'''
|
351 |
+
def __init__(
|
352 |
+
self,
|
353 |
+
size,
|
354 |
+
interpolation_mode="bilinear",
|
355 |
+
):
|
356 |
+
if isinstance(size, tuple):
|
357 |
+
if len(size) != 2:
|
358 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
359 |
+
self.size = size
|
360 |
+
else:
|
361 |
+
self.size = (size, size)
|
362 |
+
|
363 |
+
self.interpolation_mode = interpolation_mode
|
364 |
+
|
365 |
+
|
366 |
+
def __call__(self, clip):
|
367 |
+
"""
|
368 |
+
Args:
|
369 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
370 |
+
Returns:
|
371 |
+
torch.tensor: scale resized / center cropped video clip.
|
372 |
+
size is (T, C, crop_size, crop_size)
|
373 |
+
"""
|
374 |
+
clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
375 |
+
return clip_resize
|
376 |
+
|
377 |
+
def __repr__(self) -> str:
|
378 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
379 |
+
|
380 |
+
# ------------------------------------------------------------
|
381 |
+
# --------------------- Sampling ---------------------------
|
382 |
+
# ------------------------------------------------------------
|
diffusion/__init__.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
from . import gaussian_diffusion as gd
|
7 |
+
from .respace import SpacedDiffusion, space_timesteps
|
8 |
+
|
9 |
+
|
10 |
+
def create_diffusion(
|
11 |
+
timestep_respacing,
|
12 |
+
noise_schedule="linear",
|
13 |
+
use_kl=False,
|
14 |
+
sigma_small=False,
|
15 |
+
predict_xstart=False,
|
16 |
+
# learn_sigma=True,
|
17 |
+
learn_sigma=False, # for unet
|
18 |
+
rescale_learned_sigmas=False,
|
19 |
+
diffusion_steps=1000
|
20 |
+
):
|
21 |
+
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
22 |
+
if use_kl:
|
23 |
+
loss_type = gd.LossType.RESCALED_KL
|
24 |
+
elif rescale_learned_sigmas:
|
25 |
+
loss_type = gd.LossType.RESCALED_MSE
|
26 |
+
else:
|
27 |
+
loss_type = gd.LossType.MSE
|
28 |
+
if timestep_respacing is None or timestep_respacing == "":
|
29 |
+
timestep_respacing = [diffusion_steps]
|
30 |
+
return SpacedDiffusion(
|
31 |
+
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
32 |
+
betas=betas,
|
33 |
+
model_mean_type=(
|
34 |
+
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
|
35 |
+
),
|
36 |
+
model_var_type=(
|
37 |
+
(
|
38 |
+
gd.ModelVarType.FIXED_LARGE
|
39 |
+
if not sigma_small
|
40 |
+
else gd.ModelVarType.FIXED_SMALL
|
41 |
+
)
|
42 |
+
if not learn_sigma
|
43 |
+
else gd.ModelVarType.LEARNED_RANGE
|
44 |
+
),
|
45 |
+
loss_type=loss_type
|
46 |
+
# rescale_timesteps=rescale_timesteps,
|
47 |
+
)
|
diffusion/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.01 kB). View file
|
|
diffusion/__pycache__/diffusion_utils.cpython-310.pyc
ADDED
Binary file (2.83 kB). View file
|
|
diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc
ADDED
Binary file (25 kB). View file
|
|
diffusion/__pycache__/respace.cpython-310.pyc
ADDED
Binary file (4.98 kB). View file
|
|
diffusion/diffusion_utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import torch as th
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
11 |
+
"""
|
12 |
+
Compute the KL divergence between two gaussians.
|
13 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
14 |
+
scalars, among other use cases.
|
15 |
+
"""
|
16 |
+
tensor = None
|
17 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
18 |
+
if isinstance(obj, th.Tensor):
|
19 |
+
tensor = obj
|
20 |
+
break
|
21 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
22 |
+
|
23 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
24 |
+
# Tensors, but it does not work for th.exp().
|
25 |
+
logvar1, logvar2 = [
|
26 |
+
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
|
27 |
+
for x in (logvar1, logvar2)
|
28 |
+
]
|
29 |
+
|
30 |
+
return 0.5 * (
|
31 |
+
-1.0
|
32 |
+
+ logvar2
|
33 |
+
- logvar1
|
34 |
+
+ th.exp(logvar1 - logvar2)
|
35 |
+
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
def approx_standard_normal_cdf(x):
|
40 |
+
"""
|
41 |
+
A fast approximation of the cumulative distribution function of the
|
42 |
+
standard normal.
|
43 |
+
"""
|
44 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
45 |
+
|
46 |
+
|
47 |
+
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
48 |
+
"""
|
49 |
+
Compute the log-likelihood of a continuous Gaussian distribution.
|
50 |
+
:param x: the targets
|
51 |
+
:param means: the Gaussian mean Tensor.
|
52 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
53 |
+
:return: a tensor like x of log probabilities (in nats).
|
54 |
+
"""
|
55 |
+
centered_x = x - means
|
56 |
+
inv_stdv = th.exp(-log_scales)
|
57 |
+
normalized_x = centered_x * inv_stdv
|
58 |
+
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
|
59 |
+
return log_probs
|
60 |
+
|
61 |
+
|
62 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
63 |
+
"""
|
64 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
65 |
+
given image.
|
66 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
67 |
+
rescaled to the range [-1, 1].
|
68 |
+
:param means: the Gaussian mean Tensor.
|
69 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
70 |
+
:return: a tensor like x of log probabilities (in nats).
|
71 |
+
"""
|
72 |
+
assert x.shape == means.shape == log_scales.shape
|
73 |
+
centered_x = x - means
|
74 |
+
inv_stdv = th.exp(-log_scales)
|
75 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
76 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
77 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
78 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
79 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
80 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
81 |
+
cdf_delta = cdf_plus - cdf_min
|
82 |
+
log_probs = th.where(
|
83 |
+
x < -0.999,
|
84 |
+
log_cdf_plus,
|
85 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
86 |
+
)
|
87 |
+
assert log_probs.shape == x.shape
|
88 |
+
return log_probs
|
diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,931 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch as th
|
11 |
+
import enum
|
12 |
+
|
13 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
14 |
+
|
15 |
+
|
16 |
+
def mean_flat(tensor):
|
17 |
+
"""
|
18 |
+
Take the mean over all non-batch dimensions.
|
19 |
+
"""
|
20 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
21 |
+
|
22 |
+
|
23 |
+
class ModelMeanType(enum.Enum):
|
24 |
+
"""
|
25 |
+
Which type of output the model predicts.
|
26 |
+
"""
|
27 |
+
|
28 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
29 |
+
START_X = enum.auto() # the model predicts x_0
|
30 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
31 |
+
|
32 |
+
|
33 |
+
class ModelVarType(enum.Enum):
|
34 |
+
"""
|
35 |
+
What is used as the model's output variance.
|
36 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
37 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
38 |
+
"""
|
39 |
+
|
40 |
+
LEARNED = enum.auto()
|
41 |
+
FIXED_SMALL = enum.auto()
|
42 |
+
FIXED_LARGE = enum.auto()
|
43 |
+
LEARNED_RANGE = enum.auto()
|
44 |
+
|
45 |
+
|
46 |
+
class LossType(enum.Enum):
|
47 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
48 |
+
RESCALED_MSE = (
|
49 |
+
enum.auto()
|
50 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
51 |
+
KL = enum.auto() # use the variational lower-bound
|
52 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
53 |
+
|
54 |
+
def is_vb(self):
|
55 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
56 |
+
|
57 |
+
|
58 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
59 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
60 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
61 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
62 |
+
return betas
|
63 |
+
|
64 |
+
|
65 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
66 |
+
"""
|
67 |
+
This is the deprecated API for creating beta schedules.
|
68 |
+
See get_named_beta_schedule() for the new library of schedules.
|
69 |
+
"""
|
70 |
+
if beta_schedule == "quad":
|
71 |
+
betas = (
|
72 |
+
np.linspace(
|
73 |
+
beta_start ** 0.5,
|
74 |
+
beta_end ** 0.5,
|
75 |
+
num_diffusion_timesteps,
|
76 |
+
dtype=np.float64,
|
77 |
+
)
|
78 |
+
** 2
|
79 |
+
)
|
80 |
+
elif beta_schedule == "linear":
|
81 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
82 |
+
elif beta_schedule == "warmup10":
|
83 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
84 |
+
elif beta_schedule == "warmup50":
|
85 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
86 |
+
elif beta_schedule == "const":
|
87 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
88 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
89 |
+
betas = 1.0 / np.linspace(
|
90 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
raise NotImplementedError(beta_schedule)
|
94 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
95 |
+
return betas
|
96 |
+
|
97 |
+
|
98 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
99 |
+
"""
|
100 |
+
Get a pre-defined beta schedule for the given name.
|
101 |
+
The beta schedule library consists of beta schedules which remain similar
|
102 |
+
in the limit of num_diffusion_timesteps.
|
103 |
+
Beta schedules may be added, but should not be removed or changed once
|
104 |
+
they are committed to maintain backwards compatibility.
|
105 |
+
"""
|
106 |
+
if schedule_name == "linear":
|
107 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
108 |
+
# diffusion steps.
|
109 |
+
scale = 1000 / num_diffusion_timesteps
|
110 |
+
return get_beta_schedule(
|
111 |
+
"linear",
|
112 |
+
beta_start=scale * 0.0001,
|
113 |
+
beta_end=scale * 0.02,
|
114 |
+
# diffuser stable diffusion
|
115 |
+
# beta_start=scale * 0.00085,
|
116 |
+
# beta_end=scale * 0.012,
|
117 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
118 |
+
)
|
119 |
+
elif schedule_name == "squaredcos_cap_v2":
|
120 |
+
return betas_for_alpha_bar(
|
121 |
+
num_diffusion_timesteps,
|
122 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
126 |
+
|
127 |
+
|
128 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
129 |
+
"""
|
130 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
131 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
132 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
133 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
134 |
+
produces the cumulative product of (1-beta) up to that
|
135 |
+
part of the diffusion process.
|
136 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
137 |
+
prevent singularities.
|
138 |
+
"""
|
139 |
+
betas = []
|
140 |
+
for i in range(num_diffusion_timesteps):
|
141 |
+
t1 = i / num_diffusion_timesteps
|
142 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
143 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
144 |
+
return np.array(betas)
|
145 |
+
|
146 |
+
|
147 |
+
class GaussianDiffusion:
|
148 |
+
"""
|
149 |
+
Utilities for training and sampling diffusion models.
|
150 |
+
Original ported from this codebase:
|
151 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
152 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
153 |
+
starting at T and going to 1.
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
*,
|
159 |
+
betas,
|
160 |
+
model_mean_type,
|
161 |
+
model_var_type,
|
162 |
+
loss_type
|
163 |
+
):
|
164 |
+
|
165 |
+
self.model_mean_type = model_mean_type
|
166 |
+
self.model_var_type = model_var_type
|
167 |
+
self.loss_type = loss_type
|
168 |
+
|
169 |
+
# Use float64 for accuracy.
|
170 |
+
betas = np.array(betas, dtype=np.float64)
|
171 |
+
self.betas = betas
|
172 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
173 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
174 |
+
|
175 |
+
self.num_timesteps = int(betas.shape[0])
|
176 |
+
|
177 |
+
alphas = 1.0 - betas
|
178 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
179 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
180 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
181 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
182 |
+
|
183 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
184 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
185 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
186 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
187 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
188 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
189 |
+
|
190 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
191 |
+
self.posterior_variance = (
|
192 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
193 |
+
)
|
194 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
195 |
+
self.posterior_log_variance_clipped = np.log(
|
196 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
197 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
198 |
+
|
199 |
+
self.posterior_mean_coef1 = (
|
200 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
201 |
+
)
|
202 |
+
self.posterior_mean_coef2 = (
|
203 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
204 |
+
)
|
205 |
+
|
206 |
+
def q_mean_variance(self, x_start, t):
|
207 |
+
"""
|
208 |
+
Get the distribution q(x_t | x_0).
|
209 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
210 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
211 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
212 |
+
"""
|
213 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
214 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
215 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
216 |
+
return mean, variance, log_variance
|
217 |
+
|
218 |
+
def q_sample(self, x_start, t, noise=None):
|
219 |
+
"""
|
220 |
+
Diffuse the data for a given number of diffusion steps.
|
221 |
+
In other words, sample from q(x_t | x_0).
|
222 |
+
:param x_start: the initial data batch.
|
223 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
224 |
+
:param noise: if specified, the split-out normal noise.
|
225 |
+
:return: A noisy version of x_start.
|
226 |
+
"""
|
227 |
+
if noise is None:
|
228 |
+
noise = th.randn_like(x_start)
|
229 |
+
assert noise.shape == x_start.shape
|
230 |
+
return (
|
231 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
232 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
233 |
+
)
|
234 |
+
|
235 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
236 |
+
"""
|
237 |
+
Compute the mean and variance of the diffusion posterior:
|
238 |
+
q(x_{t-1} | x_t, x_0)
|
239 |
+
"""
|
240 |
+
assert x_start.shape == x_t.shape
|
241 |
+
posterior_mean = (
|
242 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
243 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
244 |
+
)
|
245 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
246 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
247 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
248 |
+
)
|
249 |
+
assert (
|
250 |
+
posterior_mean.shape[0]
|
251 |
+
== posterior_variance.shape[0]
|
252 |
+
== posterior_log_variance_clipped.shape[0]
|
253 |
+
== x_start.shape[0]
|
254 |
+
)
|
255 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
256 |
+
|
257 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
|
258 |
+
mask=None, x_start=None, use_concat=False):
|
259 |
+
"""
|
260 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
261 |
+
the initial x, x_0.
|
262 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
263 |
+
as input.
|
264 |
+
:param x: the [N x C x ...] tensor at time t.
|
265 |
+
:param t: a 1-D Tensor of timesteps.
|
266 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
267 |
+
:param denoised_fn: if not None, a function which applies to the
|
268 |
+
x_start prediction before it is used to sample. Applies before
|
269 |
+
clip_denoised.
|
270 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
271 |
+
pass to the model. This can be used for conditioning.
|
272 |
+
:return: a dict with the following keys:
|
273 |
+
- 'mean': the model mean output.
|
274 |
+
- 'variance': the model variance output.
|
275 |
+
- 'log_variance': the log of 'variance'.
|
276 |
+
- 'pred_xstart': the prediction for x_0.
|
277 |
+
"""
|
278 |
+
if model_kwargs is None:
|
279 |
+
model_kwargs = {}
|
280 |
+
|
281 |
+
B, F, C = x.shape[:3]
|
282 |
+
assert t.shape == (B,)
|
283 |
+
if use_concat:
|
284 |
+
model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs)
|
285 |
+
else:
|
286 |
+
model_output = model(x, t, **model_kwargs)
|
287 |
+
try:
|
288 |
+
model_output = model_output.sample # for tav unet
|
289 |
+
except:
|
290 |
+
pass
|
291 |
+
# model_output = model(x, t, **model_kwargs)
|
292 |
+
if isinstance(model_output, tuple):
|
293 |
+
model_output, extra = model_output
|
294 |
+
else:
|
295 |
+
extra = None
|
296 |
+
|
297 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
298 |
+
assert model_output.shape == (B, F, C * 2, *x.shape[3:])
|
299 |
+
model_output, model_var_values = th.split(model_output, C, dim=2)
|
300 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
301 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
302 |
+
# The model_var_values is [-1, 1] for [min_var, max_var].
|
303 |
+
frac = (model_var_values + 1) / 2
|
304 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
305 |
+
model_variance = th.exp(model_log_variance)
|
306 |
+
else:
|
307 |
+
model_variance, model_log_variance = {
|
308 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
309 |
+
# to get a better decoder log likelihood.
|
310 |
+
ModelVarType.FIXED_LARGE: (
|
311 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
312 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
313 |
+
),
|
314 |
+
ModelVarType.FIXED_SMALL: (
|
315 |
+
self.posterior_variance,
|
316 |
+
self.posterior_log_variance_clipped,
|
317 |
+
),
|
318 |
+
}[self.model_var_type]
|
319 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
320 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
321 |
+
|
322 |
+
def process_xstart(x):
|
323 |
+
if denoised_fn is not None:
|
324 |
+
x = denoised_fn(x)
|
325 |
+
if clip_denoised:
|
326 |
+
return x.clamp(-1, 1)
|
327 |
+
return x
|
328 |
+
|
329 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
330 |
+
pred_xstart = process_xstart(model_output)
|
331 |
+
else:
|
332 |
+
pred_xstart = process_xstart(
|
333 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
334 |
+
)
|
335 |
+
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
336 |
+
|
337 |
+
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
338 |
+
return {
|
339 |
+
"mean": model_mean,
|
340 |
+
"variance": model_variance,
|
341 |
+
"log_variance": model_log_variance,
|
342 |
+
"pred_xstart": pred_xstart,
|
343 |
+
"extra": extra,
|
344 |
+
}
|
345 |
+
|
346 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
347 |
+
assert x_t.shape == eps.shape
|
348 |
+
return (
|
349 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
350 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
351 |
+
)
|
352 |
+
|
353 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
354 |
+
return (
|
355 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
356 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
357 |
+
|
358 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
359 |
+
"""
|
360 |
+
Compute the mean for the previous step, given a function cond_fn that
|
361 |
+
computes the gradient of a conditional log probability with respect to
|
362 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
363 |
+
condition on y.
|
364 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
365 |
+
"""
|
366 |
+
gradient = cond_fn(x, t, **model_kwargs)
|
367 |
+
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
368 |
+
return new_mean
|
369 |
+
|
370 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
371 |
+
"""
|
372 |
+
Compute what the p_mean_variance output would have been, should the
|
373 |
+
model's score function be conditioned by cond_fn.
|
374 |
+
See condition_mean() for details on cond_fn.
|
375 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
376 |
+
from Song et al (2020).
|
377 |
+
"""
|
378 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
379 |
+
|
380 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
381 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
382 |
+
|
383 |
+
out = p_mean_var.copy()
|
384 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
385 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
386 |
+
return out
|
387 |
+
|
388 |
+
def p_sample(
|
389 |
+
self,
|
390 |
+
model,
|
391 |
+
x,
|
392 |
+
t,
|
393 |
+
clip_denoised=True,
|
394 |
+
denoised_fn=None,
|
395 |
+
cond_fn=None,
|
396 |
+
model_kwargs=None,
|
397 |
+
mask=None,
|
398 |
+
x_start=None,
|
399 |
+
use_concat=False
|
400 |
+
):
|
401 |
+
"""
|
402 |
+
Sample x_{t-1} from the model at the given timestep.
|
403 |
+
:param model: the model to sample from.
|
404 |
+
:param x: the current tensor at x_{t-1}.
|
405 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
406 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
407 |
+
:param denoised_fn: if not None, a function which applies to the
|
408 |
+
x_start prediction before it is used to sample.
|
409 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
410 |
+
similarly to the model.
|
411 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
412 |
+
pass to the model. This can be used for conditioning.
|
413 |
+
:return: a dict containing the following keys:
|
414 |
+
- 'sample': a random sample from the model.
|
415 |
+
- 'pred_xstart': a prediction of x_0.
|
416 |
+
"""
|
417 |
+
out = self.p_mean_variance(
|
418 |
+
model,
|
419 |
+
x,
|
420 |
+
t,
|
421 |
+
clip_denoised=clip_denoised,
|
422 |
+
denoised_fn=denoised_fn,
|
423 |
+
model_kwargs=model_kwargs,
|
424 |
+
mask=mask,
|
425 |
+
x_start=x_start,
|
426 |
+
use_concat=use_concat
|
427 |
+
)
|
428 |
+
noise = th.randn_like(x)
|
429 |
+
nonzero_mask = (
|
430 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
431 |
+
) # no noise when t == 0
|
432 |
+
if cond_fn is not None:
|
433 |
+
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
434 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
435 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
436 |
+
|
437 |
+
def p_sample_loop(
|
438 |
+
self,
|
439 |
+
model,
|
440 |
+
shape,
|
441 |
+
noise=None,
|
442 |
+
clip_denoised=True,
|
443 |
+
denoised_fn=None,
|
444 |
+
cond_fn=None,
|
445 |
+
model_kwargs=None,
|
446 |
+
device=None,
|
447 |
+
progress=False,
|
448 |
+
mask=None,
|
449 |
+
x_start=None,
|
450 |
+
use_concat=False,
|
451 |
+
):
|
452 |
+
"""
|
453 |
+
Generate samples from the model.
|
454 |
+
:param model: the model module.
|
455 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
456 |
+
:param noise: if specified, the noise from the encoder to sample.
|
457 |
+
Should be of the same shape as `shape`.
|
458 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
459 |
+
:param denoised_fn: if not None, a function which applies to the
|
460 |
+
x_start prediction before it is used to sample.
|
461 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
462 |
+
similarly to the model.
|
463 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
464 |
+
pass to the model. This can be used for conditioning.
|
465 |
+
:param device: if specified, the device to create the samples on.
|
466 |
+
If not specified, use a model parameter's device.
|
467 |
+
:param progress: if True, show a tqdm progress bar.
|
468 |
+
:return: a non-differentiable batch of samples.
|
469 |
+
"""
|
470 |
+
final = None
|
471 |
+
for sample in self.p_sample_loop_progressive(
|
472 |
+
model,
|
473 |
+
shape,
|
474 |
+
noise=noise,
|
475 |
+
clip_denoised=clip_denoised,
|
476 |
+
denoised_fn=denoised_fn,
|
477 |
+
cond_fn=cond_fn,
|
478 |
+
model_kwargs=model_kwargs,
|
479 |
+
device=device,
|
480 |
+
progress=progress,
|
481 |
+
mask=mask,
|
482 |
+
x_start=x_start,
|
483 |
+
use_concat=use_concat
|
484 |
+
):
|
485 |
+
final = sample
|
486 |
+
return final["sample"]
|
487 |
+
|
488 |
+
def p_sample_loop_progressive(
|
489 |
+
self,
|
490 |
+
model,
|
491 |
+
shape,
|
492 |
+
noise=None,
|
493 |
+
clip_denoised=True,
|
494 |
+
denoised_fn=None,
|
495 |
+
cond_fn=None,
|
496 |
+
model_kwargs=None,
|
497 |
+
device=None,
|
498 |
+
progress=False,
|
499 |
+
mask=None,
|
500 |
+
x_start=None,
|
501 |
+
use_concat=False
|
502 |
+
):
|
503 |
+
"""
|
504 |
+
Generate samples from the model and yield intermediate samples from
|
505 |
+
each timestep of diffusion.
|
506 |
+
Arguments are the same as p_sample_loop().
|
507 |
+
Returns a generator over dicts, where each dict is the return value of
|
508 |
+
p_sample().
|
509 |
+
"""
|
510 |
+
if device is None:
|
511 |
+
device = next(model.parameters()).device
|
512 |
+
assert isinstance(shape, (tuple, list))
|
513 |
+
if noise is not None:
|
514 |
+
img = noise
|
515 |
+
else:
|
516 |
+
img = th.randn(*shape, device=device)
|
517 |
+
indices = list(range(self.num_timesteps))[::-1]
|
518 |
+
|
519 |
+
if progress:
|
520 |
+
# Lazy import so that we don't depend on tqdm.
|
521 |
+
from tqdm.auto import tqdm
|
522 |
+
|
523 |
+
indices = tqdm(indices)
|
524 |
+
|
525 |
+
for i in indices:
|
526 |
+
t = th.tensor([i] * shape[0], device=device)
|
527 |
+
with th.no_grad():
|
528 |
+
out = self.p_sample(
|
529 |
+
model,
|
530 |
+
img,
|
531 |
+
t,
|
532 |
+
clip_denoised=clip_denoised,
|
533 |
+
denoised_fn=denoised_fn,
|
534 |
+
cond_fn=cond_fn,
|
535 |
+
model_kwargs=model_kwargs,
|
536 |
+
mask=mask,
|
537 |
+
x_start=x_start,
|
538 |
+
use_concat=use_concat
|
539 |
+
)
|
540 |
+
yield out
|
541 |
+
img = out["sample"]
|
542 |
+
|
543 |
+
def ddim_sample(
|
544 |
+
self,
|
545 |
+
model,
|
546 |
+
x,
|
547 |
+
t,
|
548 |
+
clip_denoised=True,
|
549 |
+
denoised_fn=None,
|
550 |
+
cond_fn=None,
|
551 |
+
model_kwargs=None,
|
552 |
+
eta=0.0,
|
553 |
+
mask=None,
|
554 |
+
x_start=None,
|
555 |
+
use_concat=False
|
556 |
+
):
|
557 |
+
"""
|
558 |
+
Sample x_{t-1} from the model using DDIM.
|
559 |
+
Same usage as p_sample().
|
560 |
+
"""
|
561 |
+
out = self.p_mean_variance(
|
562 |
+
model,
|
563 |
+
x,
|
564 |
+
t,
|
565 |
+
clip_denoised=clip_denoised,
|
566 |
+
denoised_fn=denoised_fn,
|
567 |
+
model_kwargs=model_kwargs,
|
568 |
+
mask=mask,
|
569 |
+
x_start=x_start,
|
570 |
+
use_concat=use_concat
|
571 |
+
)
|
572 |
+
if cond_fn is not None:
|
573 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
574 |
+
|
575 |
+
# Usually our model outputs epsilon, but we re-derive it
|
576 |
+
# in case we used x_start or x_prev prediction.
|
577 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
578 |
+
|
579 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
580 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
581 |
+
sigma = (
|
582 |
+
eta
|
583 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
584 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
585 |
+
)
|
586 |
+
# Equation 12.
|
587 |
+
noise = th.randn_like(x)
|
588 |
+
mean_pred = (
|
589 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
590 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
591 |
+
)
|
592 |
+
nonzero_mask = (
|
593 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
594 |
+
) # no noise when t == 0
|
595 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
596 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
597 |
+
|
598 |
+
def ddim_reverse_sample(
|
599 |
+
self,
|
600 |
+
model,
|
601 |
+
x,
|
602 |
+
t,
|
603 |
+
clip_denoised=True,
|
604 |
+
denoised_fn=None,
|
605 |
+
cond_fn=None,
|
606 |
+
model_kwargs=None,
|
607 |
+
eta=0.0,
|
608 |
+
):
|
609 |
+
"""
|
610 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
611 |
+
"""
|
612 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
613 |
+
out = self.p_mean_variance(
|
614 |
+
model,
|
615 |
+
x,
|
616 |
+
t,
|
617 |
+
clip_denoised=clip_denoised,
|
618 |
+
denoised_fn=denoised_fn,
|
619 |
+
model_kwargs=model_kwargs,
|
620 |
+
)
|
621 |
+
if cond_fn is not None:
|
622 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
623 |
+
# Usually our model outputs epsilon, but we re-derive it
|
624 |
+
# in case we used x_start or x_prev prediction.
|
625 |
+
eps = (
|
626 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
627 |
+
- out["pred_xstart"]
|
628 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
629 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
630 |
+
|
631 |
+
# Equation 12. reversed
|
632 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
633 |
+
|
634 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
635 |
+
|
636 |
+
def ddim_sample_loop(
|
637 |
+
self,
|
638 |
+
model,
|
639 |
+
shape,
|
640 |
+
noise=None,
|
641 |
+
clip_denoised=True,
|
642 |
+
denoised_fn=None,
|
643 |
+
cond_fn=None,
|
644 |
+
model_kwargs=None,
|
645 |
+
device=None,
|
646 |
+
progress=False,
|
647 |
+
eta=0.0,
|
648 |
+
mask=None,
|
649 |
+
x_start=None,
|
650 |
+
use_concat=False
|
651 |
+
):
|
652 |
+
"""
|
653 |
+
Generate samples from the model using DDIM.
|
654 |
+
Same usage as p_sample_loop().
|
655 |
+
"""
|
656 |
+
final = None
|
657 |
+
for sample in self.ddim_sample_loop_progressive(
|
658 |
+
model,
|
659 |
+
shape,
|
660 |
+
noise=noise,
|
661 |
+
clip_denoised=clip_denoised,
|
662 |
+
denoised_fn=denoised_fn,
|
663 |
+
cond_fn=cond_fn,
|
664 |
+
model_kwargs=model_kwargs,
|
665 |
+
device=device,
|
666 |
+
progress=progress,
|
667 |
+
eta=eta,
|
668 |
+
mask=mask,
|
669 |
+
x_start=x_start,
|
670 |
+
use_concat=use_concat
|
671 |
+
):
|
672 |
+
final = sample
|
673 |
+
return final["sample"]
|
674 |
+
|
675 |
+
def ddim_sample_loop_progressive(
|
676 |
+
self,
|
677 |
+
model,
|
678 |
+
shape,
|
679 |
+
noise=None,
|
680 |
+
clip_denoised=True,
|
681 |
+
denoised_fn=None,
|
682 |
+
cond_fn=None,
|
683 |
+
model_kwargs=None,
|
684 |
+
device=None,
|
685 |
+
progress=False,
|
686 |
+
eta=0.0,
|
687 |
+
mask=None,
|
688 |
+
x_start=None,
|
689 |
+
use_concat=False
|
690 |
+
):
|
691 |
+
"""
|
692 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
693 |
+
each timestep of DDIM.
|
694 |
+
Same usage as p_sample_loop_progressive().
|
695 |
+
"""
|
696 |
+
if device is None:
|
697 |
+
device = next(model.parameters()).device
|
698 |
+
assert isinstance(shape, (tuple, list))
|
699 |
+
if noise is not None:
|
700 |
+
img = noise
|
701 |
+
else:
|
702 |
+
img = th.randn(*shape, device=device)
|
703 |
+
indices = list(range(self.num_timesteps))[::-1]
|
704 |
+
|
705 |
+
if progress:
|
706 |
+
# Lazy import so that we don't depend on tqdm.
|
707 |
+
from tqdm.auto import tqdm
|
708 |
+
|
709 |
+
indices = tqdm(indices)
|
710 |
+
|
711 |
+
for i in indices:
|
712 |
+
t = th.tensor([i] * shape[0], device=device)
|
713 |
+
with th.no_grad():
|
714 |
+
out = self.ddim_sample(
|
715 |
+
model,
|
716 |
+
img,
|
717 |
+
t,
|
718 |
+
clip_denoised=clip_denoised,
|
719 |
+
denoised_fn=denoised_fn,
|
720 |
+
cond_fn=cond_fn,
|
721 |
+
model_kwargs=model_kwargs,
|
722 |
+
eta=eta,
|
723 |
+
mask=mask,
|
724 |
+
x_start=x_start,
|
725 |
+
use_concat=use_concat
|
726 |
+
)
|
727 |
+
yield out
|
728 |
+
img = out["sample"]
|
729 |
+
|
730 |
+
def _vb_terms_bpd(
|
731 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
732 |
+
):
|
733 |
+
"""
|
734 |
+
Get a term for the variational lower-bound.
|
735 |
+
The resulting units are bits (rather than nats, as one might expect).
|
736 |
+
This allows for comparison to other papers.
|
737 |
+
:return: a dict with the following keys:
|
738 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
739 |
+
- 'pred_xstart': the x_0 predictions.
|
740 |
+
"""
|
741 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
742 |
+
x_start=x_start, x_t=x_t, t=t
|
743 |
+
)
|
744 |
+
out = self.p_mean_variance(
|
745 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
746 |
+
)
|
747 |
+
kl = normal_kl(
|
748 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
749 |
+
)
|
750 |
+
kl = mean_flat(kl) / np.log(2.0)
|
751 |
+
|
752 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
753 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
754 |
+
)
|
755 |
+
assert decoder_nll.shape == x_start.shape
|
756 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
757 |
+
|
758 |
+
# At the first timestep return the decoder NLL,
|
759 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
760 |
+
output = th.where((t == 0), decoder_nll, kl)
|
761 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
762 |
+
|
763 |
+
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False):
|
764 |
+
"""
|
765 |
+
Compute training losses for a single timestep.
|
766 |
+
:param model: the model to evaluate loss on.
|
767 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
768 |
+
:param t: a batch of timestep indices.
|
769 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
770 |
+
pass to the model. This can be used for conditioning.
|
771 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
772 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
773 |
+
Some mean or variance settings may also have other keys.
|
774 |
+
"""
|
775 |
+
if model_kwargs is None:
|
776 |
+
model_kwargs = {}
|
777 |
+
if noise is None:
|
778 |
+
noise = th.randn_like(x_start)
|
779 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
780 |
+
if use_mask:
|
781 |
+
x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1)
|
782 |
+
terms = {}
|
783 |
+
|
784 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
785 |
+
terms["loss"] = self._vb_terms_bpd(
|
786 |
+
model=model,
|
787 |
+
x_start=x_start,
|
788 |
+
x_t=x_t,
|
789 |
+
t=t,
|
790 |
+
clip_denoised=False,
|
791 |
+
model_kwargs=model_kwargs,
|
792 |
+
)["output"]
|
793 |
+
if self.loss_type == LossType.RESCALED_KL:
|
794 |
+
terms["loss"] *= self.num_timesteps
|
795 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
796 |
+
model_output = model(x_t, t, **model_kwargs)
|
797 |
+
try:
|
798 |
+
# model_output = model(x_t, t, **model_kwargs).sample
|
799 |
+
model_output = model_output.sample # for tav unet
|
800 |
+
except:
|
801 |
+
pass
|
802 |
+
# model_output = model(x_t, t, **model_kwargs)
|
803 |
+
|
804 |
+
if self.model_var_type in [
|
805 |
+
ModelVarType.LEARNED,
|
806 |
+
ModelVarType.LEARNED_RANGE,
|
807 |
+
]:
|
808 |
+
B, F, C = x_t.shape[:3]
|
809 |
+
assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
|
810 |
+
model_output, model_var_values = th.split(model_output, C, dim=2)
|
811 |
+
# Learn the variance using the variational bound, but don't let
|
812 |
+
# it affect our mean prediction.
|
813 |
+
frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
|
814 |
+
terms["vb"] = self._vb_terms_bpd(
|
815 |
+
model=lambda *args, r=frozen_out: r,
|
816 |
+
x_start=x_start,
|
817 |
+
x_t=x_t,
|
818 |
+
t=t,
|
819 |
+
clip_denoised=False,
|
820 |
+
)["output"]
|
821 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
822 |
+
# Divide by 1000 for equivalence with initial implementation.
|
823 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
824 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
825 |
+
|
826 |
+
target = {
|
827 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
828 |
+
x_start=x_start, x_t=x_t, t=t
|
829 |
+
)[0],
|
830 |
+
ModelMeanType.START_X: x_start,
|
831 |
+
ModelMeanType.EPSILON: noise,
|
832 |
+
}[self.model_mean_type]
|
833 |
+
# assert model_output.shape == target.shape == x_start.shape
|
834 |
+
if use_mask:
|
835 |
+
terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2)
|
836 |
+
else:
|
837 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
838 |
+
if "vb" in terms:
|
839 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
840 |
+
else:
|
841 |
+
terms["loss"] = terms["mse"]
|
842 |
+
else:
|
843 |
+
raise NotImplementedError(self.loss_type)
|
844 |
+
|
845 |
+
return terms
|
846 |
+
|
847 |
+
def _prior_bpd(self, x_start):
|
848 |
+
"""
|
849 |
+
Get the prior KL term for the variational lower-bound, measured in
|
850 |
+
bits-per-dim.
|
851 |
+
This term can't be optimized, as it only depends on the encoder.
|
852 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
853 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
854 |
+
"""
|
855 |
+
batch_size = x_start.shape[0]
|
856 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
857 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
858 |
+
kl_prior = normal_kl(
|
859 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
860 |
+
)
|
861 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
862 |
+
|
863 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
864 |
+
"""
|
865 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
866 |
+
as well as other related quantities.
|
867 |
+
:param model: the model to evaluate loss on.
|
868 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
869 |
+
:param clip_denoised: if True, clip denoised samples.
|
870 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
871 |
+
pass to the model. This can be used for conditioning.
|
872 |
+
:return: a dict containing the following keys:
|
873 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
874 |
+
- prior_bpd: the prior term in the lower-bound.
|
875 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
876 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
877 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
878 |
+
"""
|
879 |
+
device = x_start.device
|
880 |
+
batch_size = x_start.shape[0]
|
881 |
+
|
882 |
+
vb = []
|
883 |
+
xstart_mse = []
|
884 |
+
mse = []
|
885 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
886 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
887 |
+
noise = th.randn_like(x_start)
|
888 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
889 |
+
# Calculate VLB term at the current timestep
|
890 |
+
with th.no_grad():
|
891 |
+
out = self._vb_terms_bpd(
|
892 |
+
model,
|
893 |
+
x_start=x_start,
|
894 |
+
x_t=x_t,
|
895 |
+
t=t_batch,
|
896 |
+
clip_denoised=clip_denoised,
|
897 |
+
model_kwargs=model_kwargs,
|
898 |
+
)
|
899 |
+
vb.append(out["output"])
|
900 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
901 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
902 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
903 |
+
|
904 |
+
vb = th.stack(vb, dim=1)
|
905 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
906 |
+
mse = th.stack(mse, dim=1)
|
907 |
+
|
908 |
+
prior_bpd = self._prior_bpd(x_start)
|
909 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
910 |
+
return {
|
911 |
+
"total_bpd": total_bpd,
|
912 |
+
"prior_bpd": prior_bpd,
|
913 |
+
"vb": vb,
|
914 |
+
"xstart_mse": xstart_mse,
|
915 |
+
"mse": mse,
|
916 |
+
}
|
917 |
+
|
918 |
+
|
919 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
920 |
+
"""
|
921 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
922 |
+
:param arr: the 1-D numpy array.
|
923 |
+
:param timesteps: a tensor of indices into the array to extract.
|
924 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
925 |
+
dimension equal to the length of timesteps.
|
926 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
927 |
+
"""
|
928 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
929 |
+
while len(res.shape) < len(broadcast_shape):
|
930 |
+
res = res[..., None]
|
931 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
diffusion/respace.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
|
9 |
+
from .gaussian_diffusion import GaussianDiffusion
|
10 |
+
|
11 |
+
|
12 |
+
def space_timesteps(num_timesteps, section_counts):
|
13 |
+
"""
|
14 |
+
Create a list of timesteps to use from an original diffusion process,
|
15 |
+
given the number of timesteps we want to take from equally-sized portions
|
16 |
+
of the original process.
|
17 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
18 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
19 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
20 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
21 |
+
from the DDIM paper is used, and only one section is allowed.
|
22 |
+
:param num_timesteps: the number of diffusion steps in the original
|
23 |
+
process to divide up.
|
24 |
+
:param section_counts: either a list of numbers, or a string containing
|
25 |
+
comma-separated numbers, indicating the step count
|
26 |
+
per section. As a special case, use "ddimN" where N
|
27 |
+
is a number of steps to use the striding from the
|
28 |
+
DDIM paper.
|
29 |
+
:return: a set of diffusion steps from the original process to use.
|
30 |
+
"""
|
31 |
+
if isinstance(section_counts, str):
|
32 |
+
if section_counts.startswith("ddim"):
|
33 |
+
desired_count = int(section_counts[len("ddim") :])
|
34 |
+
for i in range(1, num_timesteps):
|
35 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
36 |
+
return set(range(0, num_timesteps, i))
|
37 |
+
raise ValueError(
|
38 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
39 |
+
)
|
40 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
41 |
+
size_per = num_timesteps // len(section_counts)
|
42 |
+
extra = num_timesteps % len(section_counts)
|
43 |
+
start_idx = 0
|
44 |
+
all_steps = []
|
45 |
+
for i, section_count in enumerate(section_counts):
|
46 |
+
size = size_per + (1 if i < extra else 0)
|
47 |
+
if size < section_count:
|
48 |
+
raise ValueError(
|
49 |
+
f"cannot divide section of {size} steps into {section_count}"
|
50 |
+
)
|
51 |
+
if section_count <= 1:
|
52 |
+
frac_stride = 1
|
53 |
+
else:
|
54 |
+
frac_stride = (size - 1) / (section_count - 1)
|
55 |
+
cur_idx = 0.0
|
56 |
+
taken_steps = []
|
57 |
+
for _ in range(section_count):
|
58 |
+
taken_steps.append(start_idx + round(cur_idx))
|
59 |
+
cur_idx += frac_stride
|
60 |
+
all_steps += taken_steps
|
61 |
+
start_idx += size
|
62 |
+
return set(all_steps)
|
63 |
+
|
64 |
+
|
65 |
+
class SpacedDiffusion(GaussianDiffusion):
|
66 |
+
"""
|
67 |
+
A diffusion process which can skip steps in a base diffusion process.
|
68 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
69 |
+
original diffusion process to retain.
|
70 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, use_timesteps, **kwargs):
|
74 |
+
self.use_timesteps = set(use_timesteps)
|
75 |
+
self.timestep_map = []
|
76 |
+
self.original_num_steps = len(kwargs["betas"])
|
77 |
+
|
78 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
79 |
+
last_alpha_cumprod = 1.0
|
80 |
+
new_betas = []
|
81 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
82 |
+
if i in self.use_timesteps:
|
83 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
84 |
+
last_alpha_cumprod = alpha_cumprod
|
85 |
+
self.timestep_map.append(i)
|
86 |
+
kwargs["betas"] = np.array(new_betas)
|
87 |
+
super().__init__(**kwargs)
|
88 |
+
|
89 |
+
def p_mean_variance(
|
90 |
+
self, model, *args, **kwargs
|
91 |
+
): # pylint: disable=signature-differs
|
92 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
93 |
+
|
94 |
+
# @torch.compile
|
95 |
+
def training_losses(
|
96 |
+
self, model, *args, **kwargs
|
97 |
+
): # pylint: disable=signature-differs
|
98 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
99 |
+
|
100 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
101 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
102 |
+
|
103 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
104 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
105 |
+
|
106 |
+
def _wrap_model(self, model):
|
107 |
+
if isinstance(model, _WrappedModel):
|
108 |
+
return model
|
109 |
+
return _WrappedModel(
|
110 |
+
model, self.timestep_map, self.original_num_steps
|
111 |
+
)
|
112 |
+
|
113 |
+
def _scale_timesteps(self, t):
|
114 |
+
# Scaling is done by the wrapped model.
|
115 |
+
return t
|
116 |
+
|
117 |
+
|
118 |
+
class _WrappedModel:
|
119 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
120 |
+
self.model = model
|
121 |
+
self.timestep_map = timestep_map
|
122 |
+
# self.rescale_timesteps = rescale_timesteps
|
123 |
+
self.original_num_steps = original_num_steps
|
124 |
+
|
125 |
+
def __call__(self, x, ts, **kwargs):
|
126 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
127 |
+
new_ts = map_tensor[ts]
|
128 |
+
# if self.rescale_timesteps:
|
129 |
+
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
130 |
+
return self.model(x, new_ts, **kwargs)
|
diffusion/timestep_sampler.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch as th
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
def create_named_schedule_sampler(name, diffusion):
|
14 |
+
"""
|
15 |
+
Create a ScheduleSampler from a library of pre-defined samplers.
|
16 |
+
:param name: the name of the sampler.
|
17 |
+
:param diffusion: the diffusion object to sample for.
|
18 |
+
"""
|
19 |
+
if name == "uniform":
|
20 |
+
return UniformSampler(diffusion)
|
21 |
+
elif name == "loss-second-moment":
|
22 |
+
return LossSecondMomentResampler(diffusion)
|
23 |
+
else:
|
24 |
+
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
25 |
+
|
26 |
+
|
27 |
+
class ScheduleSampler(ABC):
|
28 |
+
"""
|
29 |
+
A distribution over timesteps in the diffusion process, intended to reduce
|
30 |
+
variance of the objective.
|
31 |
+
By default, samplers perform unbiased importance sampling, in which the
|
32 |
+
objective's mean is unchanged.
|
33 |
+
However, subclasses may override sample() to change how the resampled
|
34 |
+
terms are reweighted, allowing for actual changes in the objective.
|
35 |
+
"""
|
36 |
+
|
37 |
+
@abstractmethod
|
38 |
+
def weights(self):
|
39 |
+
"""
|
40 |
+
Get a numpy array of weights, one per diffusion step.
|
41 |
+
The weights needn't be normalized, but must be positive.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def sample(self, batch_size, device):
|
45 |
+
"""
|
46 |
+
Importance-sample timesteps for a batch.
|
47 |
+
:param batch_size: the number of timesteps.
|
48 |
+
:param device: the torch device to save to.
|
49 |
+
:return: a tuple (timesteps, weights):
|
50 |
+
- timesteps: a tensor of timestep indices.
|
51 |
+
- weights: a tensor of weights to scale the resulting losses.
|
52 |
+
"""
|
53 |
+
w = self.weights()
|
54 |
+
p = w / np.sum(w)
|
55 |
+
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
56 |
+
indices = th.from_numpy(indices_np).long().to(device)
|
57 |
+
weights_np = 1 / (len(p) * p[indices_np])
|
58 |
+
weights = th.from_numpy(weights_np).float().to(device)
|
59 |
+
return indices, weights
|
60 |
+
|
61 |
+
|
62 |
+
class UniformSampler(ScheduleSampler):
|
63 |
+
def __init__(self, diffusion):
|
64 |
+
self.diffusion = diffusion
|
65 |
+
self._weights = np.ones([diffusion.num_timesteps])
|
66 |
+
|
67 |
+
def weights(self):
|
68 |
+
return self._weights
|
69 |
+
|
70 |
+
|
71 |
+
class LossAwareSampler(ScheduleSampler):
|
72 |
+
def update_with_local_losses(self, local_ts, local_losses):
|
73 |
+
"""
|
74 |
+
Update the reweighting using losses from a model.
|
75 |
+
Call this method from each rank with a batch of timesteps and the
|
76 |
+
corresponding losses for each of those timesteps.
|
77 |
+
This method will perform synchronization to make sure all of the ranks
|
78 |
+
maintain the exact same reweighting.
|
79 |
+
:param local_ts: an integer Tensor of timesteps.
|
80 |
+
:param local_losses: a 1D Tensor of losses.
|
81 |
+
"""
|
82 |
+
batch_sizes = [
|
83 |
+
th.tensor([0], dtype=th.int32, device=local_ts.device)
|
84 |
+
for _ in range(dist.get_world_size())
|
85 |
+
]
|
86 |
+
dist.all_gather(
|
87 |
+
batch_sizes,
|
88 |
+
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
89 |
+
)
|
90 |
+
|
91 |
+
# Pad all_gather batches to be the maximum batch size.
|
92 |
+
batch_sizes = [x.item() for x in batch_sizes]
|
93 |
+
max_bs = max(batch_sizes)
|
94 |
+
|
95 |
+
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
96 |
+
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
97 |
+
dist.all_gather(timestep_batches, local_ts)
|
98 |
+
dist.all_gather(loss_batches, local_losses)
|
99 |
+
timesteps = [
|
100 |
+
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
|
101 |
+
]
|
102 |
+
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
103 |
+
self.update_with_all_losses(timesteps, losses)
|
104 |
+
|
105 |
+
@abstractmethod
|
106 |
+
def update_with_all_losses(self, ts, losses):
|
107 |
+
"""
|
108 |
+
Update the reweighting using losses from a model.
|
109 |
+
Sub-classes should override this method to update the reweighting
|
110 |
+
using losses from the model.
|
111 |
+
This method directly updates the reweighting without synchronizing
|
112 |
+
between workers. It is called by update_with_local_losses from all
|
113 |
+
ranks with identical arguments. Thus, it should have deterministic
|
114 |
+
behavior to maintain state across workers.
|
115 |
+
:param ts: a list of int timesteps.
|
116 |
+
:param losses: a list of float losses, one per timestep.
|
117 |
+
"""
|
118 |
+
|
119 |
+
|
120 |
+
class LossSecondMomentResampler(LossAwareSampler):
|
121 |
+
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
122 |
+
self.diffusion = diffusion
|
123 |
+
self.history_per_term = history_per_term
|
124 |
+
self.uniform_prob = uniform_prob
|
125 |
+
self._loss_history = np.zeros(
|
126 |
+
[diffusion.num_timesteps, history_per_term], dtype=np.float64
|
127 |
+
)
|
128 |
+
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
|
129 |
+
|
130 |
+
def weights(self):
|
131 |
+
if not self._warmed_up():
|
132 |
+
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
133 |
+
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
|
134 |
+
weights /= np.sum(weights)
|
135 |
+
weights *= 1 - self.uniform_prob
|
136 |
+
weights += self.uniform_prob / len(weights)
|
137 |
+
return weights
|
138 |
+
|
139 |
+
def update_with_all_losses(self, ts, losses):
|
140 |
+
for t, loss in zip(ts, losses):
|
141 |
+
if self._loss_counts[t] == self.history_per_term:
|
142 |
+
# Shift out the oldest loss term.
|
143 |
+
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
144 |
+
self._loss_history[t, -1] = loss
|
145 |
+
else:
|
146 |
+
self._loss_history[t, self._loss_counts[t]] = loss
|
147 |
+
self._loss_counts[t] += 1
|
148 |
+
|
149 |
+
def _warmed_up(self):
|
150 |
+
return (self._loss_counts == self.history_per_term).all()
|
input/i2v/A_big_drop_of_water_falls_on_a_rose_petal.png
ADDED
Git LFS Details
|
input/i2v/A_fish_swims_past_an_oriental_woman.png
ADDED
Git LFS Details
|
input/i2v/Cinematic_photograph_View_of_piloting_aaero.png
ADDED
Git LFS Details
|
input/i2v/Planet_hits_earth.png
ADDED
Git LFS Details
|
input/i2v/Underwater_environment_cosmetic_bottles.png
ADDED
Git LFS Details
|
models/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
4 |
+
|
5 |
+
from .unet import UNet3DConditionModel
|
6 |
+
from torch.optim.lr_scheduler import LambdaLR
|
7 |
+
|
8 |
+
def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
def fn(step):
|
11 |
+
if warmup_steps > 0:
|
12 |
+
return min(step / warmup_steps, 1)
|
13 |
+
else:
|
14 |
+
return 1
|
15 |
+
return LambdaLR(optimizer, fn)
|
16 |
+
|
17 |
+
|
18 |
+
def get_lr_scheduler(optimizer, name, **kwargs):
|
19 |
+
if name == 'warmup':
|
20 |
+
return customized_lr_scheduler(optimizer, **kwargs)
|
21 |
+
elif name == 'cosine':
|
22 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
23 |
+
return CosineAnnealingLR(optimizer, **kwargs)
|
24 |
+
else:
|
25 |
+
raise NotImplementedError(name)
|
26 |
+
|
27 |
+
def get_models(args):
|
28 |
+
if 'UNet' in args.model:
|
29 |
+
pretrained_model_path = args.pretrained_model_path
|
30 |
+
return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
|
31 |
+
else:
|
32 |
+
raise '{} Model Not Supported!'.format(args.model)
|
33 |
+
|
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.29 kB). View file
|
|
models/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (22.3 kB). View file
|
|
models/__pycache__/clip.cpython-310.pyc
ADDED
Binary file (3.65 kB). View file
|
|
models/__pycache__/resnet.cpython-310.pyc
ADDED
Binary file (5.17 kB). View file
|
|
models/__pycache__/unet.cpython-310.pyc
ADDED
Binary file (17.5 kB). View file
|
|
models/__pycache__/unet_blocks.cpython-310.pyc
ADDED
Binary file (12.4 kB). View file
|
|
models/attention.py
ADDED
@@ -0,0 +1,966 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
from copy import deepcopy
|
13 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
14 |
+
from diffusers.utils import BaseOutput
|
15 |
+
from diffusers.utils.import_utils import is_xformers_available
|
16 |
+
from diffusers.models.attention import FeedForward, AdaLayerNorm
|
17 |
+
from rotary_embedding_torch import RotaryEmbedding
|
18 |
+
from typing import Callable, Optional
|
19 |
+
from einops import rearrange, repeat
|
20 |
+
|
21 |
+
try:
|
22 |
+
from diffusers.models.modeling_utils import ModelMixin
|
23 |
+
except:
|
24 |
+
from diffusers.modeling_utils import ModelMixin # 0.11.1
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class Transformer3DModelOutput(BaseOutput):
|
29 |
+
sample: torch.FloatTensor
|
30 |
+
|
31 |
+
|
32 |
+
if is_xformers_available():
|
33 |
+
import xformers
|
34 |
+
import xformers.ops
|
35 |
+
else:
|
36 |
+
xformers = None
|
37 |
+
|
38 |
+
def exists(x):
|
39 |
+
return x is not None
|
40 |
+
|
41 |
+
|
42 |
+
class CrossAttention(nn.Module):
|
43 |
+
r"""
|
44 |
+
copy from diffuser 0.11.1
|
45 |
+
A cross attention layer.
|
46 |
+
Parameters:
|
47 |
+
query_dim (`int`): The number of channels in the query.
|
48 |
+
cross_attention_dim (`int`, *optional*):
|
49 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
50 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
51 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
52 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
53 |
+
bias (`bool`, *optional*, defaults to False):
|
54 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
query_dim: int,
|
60 |
+
cross_attention_dim: Optional[int] = None,
|
61 |
+
heads: int = 8,
|
62 |
+
dim_head: int = 64,
|
63 |
+
dropout: float = 0.0,
|
64 |
+
bias=False,
|
65 |
+
upcast_attention: bool = False,
|
66 |
+
upcast_softmax: bool = False,
|
67 |
+
added_kv_proj_dim: Optional[int] = None,
|
68 |
+
norm_num_groups: Optional[int] = None,
|
69 |
+
use_relative_position: bool = False,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
# print('num head', heads)
|
73 |
+
inner_dim = dim_head * heads
|
74 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
75 |
+
self.upcast_attention = upcast_attention
|
76 |
+
self.upcast_softmax = upcast_softmax
|
77 |
+
|
78 |
+
self.scale = dim_head**-0.5
|
79 |
+
|
80 |
+
self.heads = heads
|
81 |
+
self.dim_head = dim_head
|
82 |
+
# for slice_size > 0 the attention score computation
|
83 |
+
# is split across the batch axis to save memory
|
84 |
+
# You can set slice_size with `set_attention_slice`
|
85 |
+
self.sliceable_head_dim = heads
|
86 |
+
self._slice_size = None
|
87 |
+
self._use_memory_efficient_attention_xformers = False
|
88 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
89 |
+
|
90 |
+
if norm_num_groups is not None:
|
91 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
92 |
+
else:
|
93 |
+
self.group_norm = None
|
94 |
+
|
95 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
96 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
97 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
98 |
+
|
99 |
+
if self.added_kv_proj_dim is not None:
|
100 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
101 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
102 |
+
|
103 |
+
self.to_out = nn.ModuleList([])
|
104 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
105 |
+
self.to_out.append(nn.Dropout(dropout))
|
106 |
+
|
107 |
+
# print(use_relative_position)
|
108 |
+
self.use_relative_position = use_relative_position
|
109 |
+
if self.use_relative_position:
|
110 |
+
self.rotary_emb = RotaryEmbedding(min(32, dim_head))
|
111 |
+
|
112 |
+
self.ip_transformed = False
|
113 |
+
self.ip_scale = 1
|
114 |
+
|
115 |
+
def ip_transform(self):
|
116 |
+
if self.ip_transformed is not True:
|
117 |
+
self.ip_to_k = deepcopy(self.to_k).to(next(self.parameters()).device)
|
118 |
+
self.ip_to_v = deepcopy(self.to_v).to(next(self.parameters()).device)
|
119 |
+
self.ip_transformed = True
|
120 |
+
|
121 |
+
def ip_train_set(self):
|
122 |
+
if self.ip_transformed is True:
|
123 |
+
self.ip_to_k.requires_grad_(True)
|
124 |
+
self.ip_to_v.requires_grad_(True)
|
125 |
+
|
126 |
+
def set_scale(self, scale):
|
127 |
+
self.ip_scale = scale
|
128 |
+
|
129 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
130 |
+
batch_size, seq_len, dim = tensor.shape
|
131 |
+
head_size = self.heads
|
132 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
133 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
134 |
+
return tensor
|
135 |
+
|
136 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
137 |
+
batch_size, seq_len, dim = tensor.shape
|
138 |
+
head_size = self.heads
|
139 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
140 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
141 |
+
return tensor
|
142 |
+
|
143 |
+
def reshape_for_scores(self, tensor):
|
144 |
+
# split heads and dims
|
145 |
+
# tensor should be [b (h w)] f (d nd)
|
146 |
+
batch_size, seq_len, dim = tensor.shape
|
147 |
+
head_size = self.heads
|
148 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
149 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
150 |
+
return tensor
|
151 |
+
|
152 |
+
def same_batch_dim_to_heads(self, tensor):
|
153 |
+
batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
|
154 |
+
tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
|
155 |
+
return tensor
|
156 |
+
|
157 |
+
def set_attention_slice(self, slice_size):
|
158 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
159 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
160 |
+
|
161 |
+
self._slice_size = slice_size
|
162 |
+
|
163 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None):
|
164 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
165 |
+
|
166 |
+
encoder_hidden_states = encoder_hidden_states
|
167 |
+
|
168 |
+
if self.group_norm is not None:
|
169 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
170 |
+
|
171 |
+
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
|
172 |
+
|
173 |
+
dim = query.shape[-1]
|
174 |
+
if not self.use_relative_position:
|
175 |
+
query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
|
176 |
+
|
177 |
+
if self.added_kv_proj_dim is not None:
|
178 |
+
key = self.to_k(hidden_states)
|
179 |
+
value = self.to_v(hidden_states)
|
180 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
181 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
182 |
+
|
183 |
+
key = self.reshape_heads_to_batch_dim(key)
|
184 |
+
value = self.reshape_heads_to_batch_dim(value)
|
185 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
186 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
187 |
+
|
188 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
189 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
190 |
+
else:
|
191 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
192 |
+
key = self.to_k(encoder_hidden_states)
|
193 |
+
value = self.to_v(encoder_hidden_states)
|
194 |
+
|
195 |
+
if not self.use_relative_position:
|
196 |
+
key = self.reshape_heads_to_batch_dim(key)
|
197 |
+
value = self.reshape_heads_to_batch_dim(value)
|
198 |
+
|
199 |
+
if self.ip_transformed is True and ip_hidden_states is not None:
|
200 |
+
# print(ip_hidden_states.dtype)
|
201 |
+
# print(self.ip_to_k.weight.dtype)
|
202 |
+
ip_key = self.ip_to_k(ip_hidden_states)
|
203 |
+
ip_value = self.ip_to_v(ip_hidden_states)
|
204 |
+
|
205 |
+
if not self.use_relative_position:
|
206 |
+
ip_key = self.reshape_heads_to_batch_dim(ip_key)
|
207 |
+
ip_value = self.reshape_heads_to_batch_dim(ip_value)
|
208 |
+
|
209 |
+
if attention_mask is not None:
|
210 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
211 |
+
target_length = query.shape[1]
|
212 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
213 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
214 |
+
|
215 |
+
# attention, what we cannot get enough of
|
216 |
+
if self._use_memory_efficient_attention_xformers:
|
217 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
218 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
219 |
+
hidden_states = hidden_states.to(query.dtype)
|
220 |
+
|
221 |
+
if self.ip_transformed is True and ip_hidden_states is not None:
|
222 |
+
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, attention_mask)
|
223 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
224 |
+
|
225 |
+
else:
|
226 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
227 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
228 |
+
|
229 |
+
if self.ip_transformed is True and ip_hidden_states is not None:
|
230 |
+
ip_hidden_states = self._attention(query, ip_key, ip_value, attention_mask)
|
231 |
+
else:
|
232 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
233 |
+
|
234 |
+
if self.ip_transformed is True and ip_hidden_states is not None:
|
235 |
+
ip_hidden_states = self._sliced_attention(query, ip_key, ip_value, sequence_length, dim, attention_mask)
|
236 |
+
|
237 |
+
if self.ip_transformed is True and ip_hidden_states is not None:
|
238 |
+
hidden_states = hidden_states + self.ip_scale * ip_hidden_states
|
239 |
+
|
240 |
+
# linear proj
|
241 |
+
hidden_states = self.to_out[0](hidden_states)
|
242 |
+
|
243 |
+
# dropout
|
244 |
+
hidden_states = self.to_out[1](hidden_states)
|
245 |
+
return hidden_states
|
246 |
+
|
247 |
+
|
248 |
+
def _attention(self, query, key, value, attention_mask=None):
|
249 |
+
if self.upcast_attention:
|
250 |
+
query = query.float()
|
251 |
+
key = key.float()
|
252 |
+
|
253 |
+
attention_scores = torch.baddbmm(
|
254 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
255 |
+
query,
|
256 |
+
key.transpose(-1, -2),
|
257 |
+
beta=0,
|
258 |
+
alpha=self.scale,
|
259 |
+
)
|
260 |
+
|
261 |
+
if attention_mask is not None:
|
262 |
+
attention_scores = attention_scores + attention_mask
|
263 |
+
|
264 |
+
if self.upcast_softmax:
|
265 |
+
attention_scores = attention_scores.float()
|
266 |
+
|
267 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
268 |
+
attention_probs = attention_probs.to(value.dtype)
|
269 |
+
hidden_states = torch.bmm(attention_probs, value)
|
270 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
271 |
+
return hidden_states
|
272 |
+
|
273 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
274 |
+
batch_size_attention = query.shape[0]
|
275 |
+
hidden_states = torch.zeros(
|
276 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
277 |
+
)
|
278 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
279 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
280 |
+
start_idx = i * slice_size
|
281 |
+
end_idx = (i + 1) * slice_size
|
282 |
+
|
283 |
+
query_slice = query[start_idx:end_idx]
|
284 |
+
key_slice = key[start_idx:end_idx]
|
285 |
+
|
286 |
+
if self.upcast_attention:
|
287 |
+
query_slice = query_slice.float()
|
288 |
+
key_slice = key_slice.float()
|
289 |
+
|
290 |
+
attn_slice = torch.baddbmm(
|
291 |
+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
292 |
+
query_slice,
|
293 |
+
key_slice.transpose(-1, -2),
|
294 |
+
beta=0,
|
295 |
+
alpha=self.scale,
|
296 |
+
)
|
297 |
+
|
298 |
+
if attention_mask is not None:
|
299 |
+
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
300 |
+
|
301 |
+
if self.upcast_softmax:
|
302 |
+
attn_slice = attn_slice.float()
|
303 |
+
|
304 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
305 |
+
|
306 |
+
# cast back to the original dtype
|
307 |
+
attn_slice = attn_slice.to(value.dtype)
|
308 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
309 |
+
|
310 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
311 |
+
|
312 |
+
# reshape hidden_states
|
313 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
314 |
+
return hidden_states
|
315 |
+
|
316 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
317 |
+
# TODO attention_mask
|
318 |
+
query = query.contiguous()
|
319 |
+
key = key.contiguous()
|
320 |
+
value = value.contiguous()
|
321 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
322 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
323 |
+
return hidden_states
|
324 |
+
|
325 |
+
|
326 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
327 |
+
@register_to_config
|
328 |
+
def __init__(
|
329 |
+
self,
|
330 |
+
num_attention_heads: int = 16,
|
331 |
+
attention_head_dim: int = 88,
|
332 |
+
in_channels: Optional[int] = None,
|
333 |
+
num_layers: int = 1,
|
334 |
+
dropout: float = 0.0,
|
335 |
+
norm_num_groups: int = 32,
|
336 |
+
cross_attention_dim: Optional[int] = None,
|
337 |
+
attention_bias: bool = False,
|
338 |
+
activation_fn: str = "geglu",
|
339 |
+
num_embeds_ada_norm: Optional[int] = None,
|
340 |
+
use_linear_projection: bool = False,
|
341 |
+
only_cross_attention: bool = False,
|
342 |
+
upcast_attention: bool = False,
|
343 |
+
use_first_frame: bool = False,
|
344 |
+
use_relative_position: bool = False,
|
345 |
+
rotary_emb: bool = None,
|
346 |
+
):
|
347 |
+
super().__init__()
|
348 |
+
self.use_linear_projection = use_linear_projection
|
349 |
+
self.num_attention_heads = num_attention_heads
|
350 |
+
self.attention_head_dim = attention_head_dim
|
351 |
+
inner_dim = num_attention_heads * attention_head_dim
|
352 |
+
|
353 |
+
# Define input layers
|
354 |
+
self.in_channels = in_channels
|
355 |
+
|
356 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
357 |
+
if use_linear_projection:
|
358 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
359 |
+
else:
|
360 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
361 |
+
|
362 |
+
# Define transformers blocks
|
363 |
+
self.transformer_blocks = nn.ModuleList(
|
364 |
+
[
|
365 |
+
BasicTransformerBlock(
|
366 |
+
inner_dim,
|
367 |
+
num_attention_heads,
|
368 |
+
attention_head_dim,
|
369 |
+
dropout=dropout,
|
370 |
+
cross_attention_dim=cross_attention_dim,
|
371 |
+
activation_fn=activation_fn,
|
372 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
373 |
+
attention_bias=attention_bias,
|
374 |
+
only_cross_attention=only_cross_attention,
|
375 |
+
upcast_attention=upcast_attention,
|
376 |
+
use_first_frame=use_first_frame,
|
377 |
+
use_relative_position=use_relative_position,
|
378 |
+
rotary_emb=rotary_emb,
|
379 |
+
)
|
380 |
+
for d in range(num_layers)
|
381 |
+
]
|
382 |
+
)
|
383 |
+
|
384 |
+
# 4. Define output layers
|
385 |
+
if use_linear_projection:
|
386 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
387 |
+
else:
|
388 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
389 |
+
|
390 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True, ip_hidden_states=None, encoder_temporal_hidden_states=None):
|
391 |
+
# Input
|
392 |
+
# if ip_hidden_states is not None:
|
393 |
+
# ip_hidden_states = ip_hidden_states.to(dtype=encoder_hidden_states.dtype)
|
394 |
+
# print(ip_hidden_states.shape)
|
395 |
+
# print(encoder_hidden_states.shape)
|
396 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
397 |
+
if self.training:
|
398 |
+
video_length = hidden_states.shape[2] - use_image_num
|
399 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
|
400 |
+
encoder_hidden_states_length = encoder_hidden_states.shape[1]
|
401 |
+
encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
|
402 |
+
encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
|
403 |
+
encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
|
404 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
|
405 |
+
encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
|
406 |
+
|
407 |
+
if ip_hidden_states is not None:
|
408 |
+
ip_hidden_states_length = ip_hidden_states.shape[1]
|
409 |
+
ip_hidden_states_video = ip_hidden_states[:, :ip_hidden_states_length - use_image_num, ...]
|
410 |
+
ip_hidden_states_video = repeat(ip_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
|
411 |
+
ip_hidden_states_image = ip_hidden_states[:, ip_hidden_states_length - use_image_num:, ...]
|
412 |
+
ip_hidden_states = torch.cat([ip_hidden_states_video, ip_hidden_states_image], dim=1)
|
413 |
+
ip_hidden_states = rearrange(ip_hidden_states, 'b m n c -> (b m) n c').contiguous()
|
414 |
+
|
415 |
+
else:
|
416 |
+
video_length = hidden_states.shape[2]
|
417 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
|
418 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
|
419 |
+
|
420 |
+
if encoder_temporal_hidden_states is not None:
|
421 |
+
encoder_temporal_hidden_states = repeat(encoder_temporal_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
|
422 |
+
|
423 |
+
if ip_hidden_states is not None:
|
424 |
+
ip_hidden_states = repeat(ip_hidden_states, 'b 1 n c -> (b f) n c', f=video_length).contiguous()
|
425 |
+
|
426 |
+
batch, channel, height, weight = hidden_states.shape
|
427 |
+
residual = hidden_states
|
428 |
+
|
429 |
+
hidden_states = self.norm(hidden_states)
|
430 |
+
if not self.use_linear_projection:
|
431 |
+
hidden_states = self.proj_in(hidden_states)
|
432 |
+
inner_dim = hidden_states.shape[1]
|
433 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
434 |
+
else:
|
435 |
+
inner_dim = hidden_states.shape[1]
|
436 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
437 |
+
hidden_states = self.proj_in(hidden_states)
|
438 |
+
|
439 |
+
# Blocks
|
440 |
+
for block in self.transformer_blocks:
|
441 |
+
hidden_states = block(
|
442 |
+
hidden_states,
|
443 |
+
encoder_hidden_states=encoder_hidden_states,
|
444 |
+
timestep=timestep,
|
445 |
+
video_length=video_length,
|
446 |
+
use_image_num=use_image_num,
|
447 |
+
ip_hidden_states=ip_hidden_states,
|
448 |
+
encoder_temporal_hidden_states=encoder_temporal_hidden_states
|
449 |
+
)
|
450 |
+
|
451 |
+
# Output
|
452 |
+
if not self.use_linear_projection:
|
453 |
+
hidden_states = (
|
454 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
455 |
+
)
|
456 |
+
hidden_states = self.proj_out(hidden_states)
|
457 |
+
else:
|
458 |
+
hidden_states = self.proj_out(hidden_states)
|
459 |
+
hidden_states = (
|
460 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
461 |
+
)
|
462 |
+
|
463 |
+
output = hidden_states + residual
|
464 |
+
|
465 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
|
466 |
+
if not return_dict:
|
467 |
+
return (output,)
|
468 |
+
|
469 |
+
return Transformer3DModelOutput(sample=output)
|
470 |
+
|
471 |
+
|
472 |
+
class BasicTransformerBlock(nn.Module):
|
473 |
+
def __init__(
|
474 |
+
self,
|
475 |
+
dim: int,
|
476 |
+
num_attention_heads: int,
|
477 |
+
attention_head_dim: int,
|
478 |
+
dropout=0.0,
|
479 |
+
cross_attention_dim: Optional[int] = None,
|
480 |
+
activation_fn: str = "geglu",
|
481 |
+
num_embeds_ada_norm: Optional[int] = None,
|
482 |
+
attention_bias: bool = False,
|
483 |
+
only_cross_attention: bool = False,
|
484 |
+
upcast_attention: bool = False,
|
485 |
+
use_first_frame: bool = False,
|
486 |
+
use_relative_position: bool = False,
|
487 |
+
rotary_emb: bool = False,
|
488 |
+
):
|
489 |
+
super().__init__()
|
490 |
+
self.only_cross_attention = only_cross_attention
|
491 |
+
# print(only_cross_attention)
|
492 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
493 |
+
# print(self.use_ada_layer_norm)
|
494 |
+
self.use_first_frame = use_first_frame
|
495 |
+
|
496 |
+
self.dim = dim
|
497 |
+
self.cross_attention_dim = cross_attention_dim
|
498 |
+
self.num_attention_heads = num_attention_heads
|
499 |
+
self.attention_head_dim = attention_head_dim
|
500 |
+
self.dropout = dropout
|
501 |
+
self.attention_bias = attention_bias
|
502 |
+
self.upcast_attention = upcast_attention
|
503 |
+
|
504 |
+
# Spatial-Attn
|
505 |
+
self.attn1 = CrossAttention(
|
506 |
+
query_dim=dim,
|
507 |
+
heads=num_attention_heads,
|
508 |
+
dim_head=attention_head_dim,
|
509 |
+
dropout=dropout,
|
510 |
+
bias=attention_bias,
|
511 |
+
cross_attention_dim=None,
|
512 |
+
upcast_attention=upcast_attention,
|
513 |
+
)
|
514 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
515 |
+
|
516 |
+
# Text Cross-Attn
|
517 |
+
if cross_attention_dim is not None:
|
518 |
+
self.attn2 = CrossAttention(
|
519 |
+
query_dim=dim,
|
520 |
+
cross_attention_dim=cross_attention_dim,
|
521 |
+
heads=num_attention_heads,
|
522 |
+
dim_head=attention_head_dim,
|
523 |
+
dropout=dropout,
|
524 |
+
bias=attention_bias,
|
525 |
+
upcast_attention=upcast_attention,
|
526 |
+
)
|
527 |
+
else:
|
528 |
+
self.attn2 = None
|
529 |
+
|
530 |
+
if cross_attention_dim is not None:
|
531 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
532 |
+
else:
|
533 |
+
self.norm2 = None
|
534 |
+
|
535 |
+
# Temp
|
536 |
+
self.attn_temp = TemporalAttention(
|
537 |
+
query_dim=dim,
|
538 |
+
heads=num_attention_heads,
|
539 |
+
dim_head=attention_head_dim,
|
540 |
+
dropout=dropout,
|
541 |
+
bias=attention_bias,
|
542 |
+
cross_attention_dim=None,
|
543 |
+
upcast_attention=upcast_attention,
|
544 |
+
rotary_emb=rotary_emb,
|
545 |
+
)
|
546 |
+
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
547 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
548 |
+
|
549 |
+
# Feed-forward
|
550 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
551 |
+
self.norm3 = nn.LayerNorm(dim)
|
552 |
+
|
553 |
+
self.tca_transformed = False
|
554 |
+
|
555 |
+
def tca_transform(self):
|
556 |
+
if self.tca_transformed is not True:
|
557 |
+
self.cross_attn_temp = CrossAttention(
|
558 |
+
query_dim=self.dim * 16,
|
559 |
+
cross_attention_dim=self.cross_attention_dim,
|
560 |
+
heads=self.num_attention_heads,
|
561 |
+
dim_head=self.attention_head_dim,
|
562 |
+
dropout=self.dropout,
|
563 |
+
bias=self.attention_bias,
|
564 |
+
upcast_attention=self.upcast_attention,
|
565 |
+
)
|
566 |
+
self.cross_norm_temp = AdaLayerNorm(self.dim * 16, self.num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(self.dim * 16)
|
567 |
+
nn.init.zeros_(self.cross_attn_temp.to_out[0].weight.data)
|
568 |
+
self.tca_transformed = True
|
569 |
+
|
570 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
|
571 |
+
|
572 |
+
if not is_xformers_available():
|
573 |
+
print("Here is how to install it")
|
574 |
+
raise ModuleNotFoundError(
|
575 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
576 |
+
" xformers",
|
577 |
+
name="xformers",
|
578 |
+
)
|
579 |
+
elif not torch.cuda.is_available():
|
580 |
+
raise ValueError(
|
581 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
582 |
+
" available for GPU "
|
583 |
+
)
|
584 |
+
else:
|
585 |
+
try:
|
586 |
+
# Make sure we can run the memory efficient attention
|
587 |
+
_ = xformers.ops.memory_efficient_attention(
|
588 |
+
torch.randn((1, 2, 40), device="cuda"),
|
589 |
+
torch.randn((1, 2, 40), device="cuda"),
|
590 |
+
torch.randn((1, 2, 40), device="cuda"),
|
591 |
+
)
|
592 |
+
except Exception as e:
|
593 |
+
raise e
|
594 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
595 |
+
if self.attn2 is not None:
|
596 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
597 |
+
|
598 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
|
599 |
+
# SparseCausal-Attention
|
600 |
+
norm_hidden_states = (
|
601 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
602 |
+
)
|
603 |
+
|
604 |
+
if self.only_cross_attention:
|
605 |
+
hidden_states = (
|
606 |
+
self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
607 |
+
)
|
608 |
+
else:
|
609 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states
|
610 |
+
|
611 |
+
if self.attn2 is not None:
|
612 |
+
# Cross-Attention
|
613 |
+
norm_hidden_states = (
|
614 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
615 |
+
)
|
616 |
+
hidden_states = (
|
617 |
+
self.attn2(
|
618 |
+
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ip_hidden_states=ip_hidden_states
|
619 |
+
)
|
620 |
+
+ hidden_states
|
621 |
+
)
|
622 |
+
|
623 |
+
# Temporal Attention
|
624 |
+
if self.training:
|
625 |
+
d = hidden_states.shape[1]
|
626 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
|
627 |
+
hidden_states_video = hidden_states[:, :video_length, :]
|
628 |
+
hidden_states_image = hidden_states[:, video_length:, :]
|
629 |
+
norm_hidden_states_video = (
|
630 |
+
self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
|
631 |
+
)
|
632 |
+
hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
|
633 |
+
|
634 |
+
# Temporal Cross Attention
|
635 |
+
if self.tca_transformed is True:
|
636 |
+
hidden_states_video = rearrange(hidden_states_video, "(b d) f c -> b d (f c)", d=d).contiguous()
|
637 |
+
norm_hidden_states_video = (
|
638 |
+
self.cross_norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.cross_norm_temp(hidden_states_video)
|
639 |
+
)
|
640 |
+
temp_encoder_hidden_states = rearrange(encoder_hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
|
641 |
+
temp_encoder_hidden_states = temp_encoder_hidden_states[:, 0:1].squeeze(dim=1)
|
642 |
+
hidden_states_video = self.cross_attn_temp(norm_hidden_states_video, encoder_hidden_states=temp_encoder_hidden_states, attention_mask=attention_mask) + hidden_states_video
|
643 |
+
hidden_states_video = rearrange(hidden_states_video, "b d (f c) -> (b d) f c", f=video_length).contiguous()
|
644 |
+
|
645 |
+
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
646 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
|
647 |
+
else:
|
648 |
+
d = hidden_states.shape[1]
|
649 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
|
650 |
+
norm_hidden_states = (
|
651 |
+
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
|
652 |
+
)
|
653 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
654 |
+
|
655 |
+
# Temporal Cross Attention
|
656 |
+
if self.tca_transformed is True:
|
657 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> b d (f c)", d=d).contiguous()
|
658 |
+
norm_hidden_states = (
|
659 |
+
self.cross_norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.cross_norm_temp(hidden_states)
|
660 |
+
)
|
661 |
+
if encoder_temporal_hidden_states is not None:
|
662 |
+
encoder_hidden_states = encoder_temporal_hidden_states
|
663 |
+
temp_encoder_hidden_states = rearrange(encoder_hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
|
664 |
+
temp_encoder_hidden_states = temp_encoder_hidden_states[:, 0:1].squeeze(dim=1)
|
665 |
+
hidden_states = self.cross_attn_temp(norm_hidden_states, encoder_hidden_states=temp_encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
666 |
+
hidden_states = rearrange(hidden_states, "b d (f c) -> (b f) d c", f=video_length + use_image_num, d=d).contiguous()
|
667 |
+
else:
|
668 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
|
669 |
+
|
670 |
+
# Feed-forward
|
671 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
672 |
+
|
673 |
+
return hidden_states
|
674 |
+
|
675 |
+
|
676 |
+
class SparseCausalAttention(CrossAttention):
|
677 |
+
def forward_video(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
678 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
679 |
+
|
680 |
+
encoder_hidden_states = encoder_hidden_states
|
681 |
+
|
682 |
+
if self.group_norm is not None:
|
683 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
684 |
+
|
685 |
+
query = self.to_q(hidden_states)
|
686 |
+
dim = query.shape[-1]
|
687 |
+
query = self.reshape_heads_to_batch_dim(query)
|
688 |
+
|
689 |
+
if self.added_kv_proj_dim is not None:
|
690 |
+
raise NotImplementedError
|
691 |
+
|
692 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
693 |
+
key = self.to_k(encoder_hidden_states)
|
694 |
+
value = self.to_v(encoder_hidden_states)
|
695 |
+
|
696 |
+
former_frame_index = torch.arange(video_length) - 1
|
697 |
+
former_frame_index[0] = 0
|
698 |
+
|
699 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=video_length).contiguous()
|
700 |
+
key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
|
701 |
+
key = rearrange(key, "b f d c -> (b f) d c").contiguous()
|
702 |
+
|
703 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=video_length).contiguous()
|
704 |
+
value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
|
705 |
+
value = rearrange(value, "b f d c -> (b f) d c").contiguous()
|
706 |
+
|
707 |
+
key = self.reshape_heads_to_batch_dim(key)
|
708 |
+
value = self.reshape_heads_to_batch_dim(value)
|
709 |
+
|
710 |
+
if attention_mask is not None:
|
711 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
712 |
+
target_length = query.shape[1]
|
713 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
714 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
715 |
+
|
716 |
+
# attention, what we cannot get enough of
|
717 |
+
if self._use_memory_efficient_attention_xformers:
|
718 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
719 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
720 |
+
hidden_states = hidden_states.to(query.dtype)
|
721 |
+
else:
|
722 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
723 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
724 |
+
else:
|
725 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
726 |
+
|
727 |
+
# linear proj
|
728 |
+
hidden_states = self.to_out[0](hidden_states)
|
729 |
+
|
730 |
+
# dropout
|
731 |
+
hidden_states = self.to_out[1](hidden_states)
|
732 |
+
return hidden_states
|
733 |
+
|
734 |
+
def forward_image(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
|
735 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
736 |
+
|
737 |
+
encoder_hidden_states = encoder_hidden_states
|
738 |
+
|
739 |
+
if self.group_norm is not None:
|
740 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
741 |
+
|
742 |
+
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
|
743 |
+
dim = query.shape[-1]
|
744 |
+
if not self.use_relative_position:
|
745 |
+
query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
|
746 |
+
|
747 |
+
if self.added_kv_proj_dim is not None:
|
748 |
+
key = self.to_k(hidden_states)
|
749 |
+
value = self.to_v(hidden_states)
|
750 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
751 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
752 |
+
|
753 |
+
key = self.reshape_heads_to_batch_dim(key)
|
754 |
+
value = self.reshape_heads_to_batch_dim(value)
|
755 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
756 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
757 |
+
|
758 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
759 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
760 |
+
else:
|
761 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
762 |
+
key = self.to_k(encoder_hidden_states)
|
763 |
+
value = self.to_v(encoder_hidden_states)
|
764 |
+
|
765 |
+
if not self.use_relative_position:
|
766 |
+
key = self.reshape_heads_to_batch_dim(key)
|
767 |
+
value = self.reshape_heads_to_batch_dim(value)
|
768 |
+
|
769 |
+
if attention_mask is not None:
|
770 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
771 |
+
target_length = query.shape[1]
|
772 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
773 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
774 |
+
|
775 |
+
# attention, what we cannot get enough of
|
776 |
+
if self._use_memory_efficient_attention_xformers:
|
777 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
778 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
779 |
+
hidden_states = hidden_states.to(query.dtype)
|
780 |
+
else:
|
781 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
782 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
783 |
+
else:
|
784 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
785 |
+
|
786 |
+
# linear proj
|
787 |
+
hidden_states = self.to_out[0](hidden_states)
|
788 |
+
|
789 |
+
# dropout
|
790 |
+
hidden_states = self.to_out[1](hidden_states)
|
791 |
+
return hidden_states
|
792 |
+
|
793 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_image_num=None):
|
794 |
+
if self.training:
|
795 |
+
# print(use_image_num)
|
796 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
|
797 |
+
hidden_states_video = hidden_states[:, :video_length, ...]
|
798 |
+
hidden_states_image = hidden_states[:, video_length:, ...]
|
799 |
+
hidden_states_video = rearrange(hidden_states_video, 'b f d c -> (b f) d c').contiguous()
|
800 |
+
hidden_states_image = rearrange(hidden_states_image, 'b f d c -> (b f) d c').contiguous()
|
801 |
+
hidden_states_video = self.forward_video(hidden_states=hidden_states_video,
|
802 |
+
encoder_hidden_states=encoder_hidden_states,
|
803 |
+
attention_mask=attention_mask,
|
804 |
+
video_length=video_length)
|
805 |
+
hidden_states_image = self.forward_image(hidden_states=hidden_states_image,
|
806 |
+
encoder_hidden_states=encoder_hidden_states,
|
807 |
+
attention_mask=attention_mask)
|
808 |
+
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=0)
|
809 |
+
return hidden_states
|
810 |
+
# exit()
|
811 |
+
else:
|
812 |
+
return self.forward_video(hidden_states=hidden_states,
|
813 |
+
encoder_hidden_states=encoder_hidden_states,
|
814 |
+
attention_mask=attention_mask,
|
815 |
+
video_length=video_length)
|
816 |
+
|
817 |
+
class TemporalAttention(CrossAttention):
|
818 |
+
def __init__(self,
|
819 |
+
query_dim: int,
|
820 |
+
cross_attention_dim: Optional[int] = None,
|
821 |
+
heads: int = 8,
|
822 |
+
dim_head: int = 64,
|
823 |
+
dropout: float = 0.0,
|
824 |
+
bias=False,
|
825 |
+
upcast_attention: bool = False,
|
826 |
+
upcast_softmax: bool = False,
|
827 |
+
added_kv_proj_dim: Optional[int] = None,
|
828 |
+
norm_num_groups: Optional[int] = None,
|
829 |
+
rotary_emb=None):
|
830 |
+
super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
|
831 |
+
# relative time positional embeddings
|
832 |
+
self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
|
833 |
+
self.rotary_emb = rotary_emb
|
834 |
+
|
835 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
836 |
+
time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
|
837 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
838 |
+
|
839 |
+
encoder_hidden_states = encoder_hidden_states
|
840 |
+
|
841 |
+
if self.group_norm is not None:
|
842 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
843 |
+
|
844 |
+
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
|
845 |
+
dim = query.shape[-1]
|
846 |
+
|
847 |
+
if self.added_kv_proj_dim is not None:
|
848 |
+
key = self.to_k(hidden_states)
|
849 |
+
value = self.to_v(hidden_states)
|
850 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
851 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
852 |
+
|
853 |
+
key = self.reshape_heads_to_batch_dim(key)
|
854 |
+
value = self.reshape_heads_to_batch_dim(value)
|
855 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
856 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
857 |
+
|
858 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
859 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
860 |
+
else:
|
861 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
862 |
+
key = self.to_k(encoder_hidden_states)
|
863 |
+
value = self.to_v(encoder_hidden_states)
|
864 |
+
|
865 |
+
if attention_mask is not None:
|
866 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
867 |
+
target_length = query.shape[1]
|
868 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
869 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
870 |
+
|
871 |
+
# attention, what we cannot get enough of
|
872 |
+
if self._use_memory_efficient_attention_xformers:
|
873 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
874 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
875 |
+
hidden_states = hidden_states.to(query.dtype)
|
876 |
+
else:
|
877 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
878 |
+
hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
|
879 |
+
else:
|
880 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
881 |
+
|
882 |
+
# linear proj
|
883 |
+
hidden_states = self.to_out[0](hidden_states)
|
884 |
+
|
885 |
+
# dropout
|
886 |
+
hidden_states = self.to_out[1](hidden_states)
|
887 |
+
return hidden_states
|
888 |
+
|
889 |
+
|
890 |
+
def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
|
891 |
+
if self.upcast_attention:
|
892 |
+
query = query.float()
|
893 |
+
key = key.float()
|
894 |
+
|
895 |
+
query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
|
896 |
+
key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
|
897 |
+
value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
|
898 |
+
|
899 |
+
# torch.baddbmm only accepte 3-D tensor
|
900 |
+
# https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
|
901 |
+
# attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
|
902 |
+
if exists(self.rotary_emb):
|
903 |
+
query = self.rotary_emb.rotate_queries_or_keys(query)
|
904 |
+
key = self.rotary_emb.rotate_queries_or_keys(key)
|
905 |
+
|
906 |
+
attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
|
907 |
+
|
908 |
+
attention_scores = attention_scores + time_rel_pos_bias
|
909 |
+
|
910 |
+
if attention_mask is not None:
|
911 |
+
# add attention mask
|
912 |
+
attention_scores = attention_scores + attention_mask
|
913 |
+
|
914 |
+
# vdm
|
915 |
+
attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
|
916 |
+
|
917 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
918 |
+
# print(attention_probs[0][0])
|
919 |
+
|
920 |
+
# cast back to the original dtype
|
921 |
+
attention_probs = attention_probs.to(value.dtype)
|
922 |
+
|
923 |
+
# compute attention output
|
924 |
+
hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
|
925 |
+
hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
|
926 |
+
return hidden_states
|
927 |
+
|
928 |
+
class RelativePositionBias(nn.Module):
|
929 |
+
def __init__(
|
930 |
+
self,
|
931 |
+
heads=8,
|
932 |
+
num_buckets=32,
|
933 |
+
max_distance=128,
|
934 |
+
):
|
935 |
+
super().__init__()
|
936 |
+
self.num_buckets = num_buckets
|
937 |
+
self.max_distance = max_distance
|
938 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
939 |
+
|
940 |
+
@staticmethod
|
941 |
+
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
|
942 |
+
ret = 0
|
943 |
+
n = -relative_position
|
944 |
+
|
945 |
+
num_buckets //= 2
|
946 |
+
ret += (n < 0).long() * num_buckets
|
947 |
+
n = torch.abs(n)
|
948 |
+
|
949 |
+
max_exact = num_buckets // 2
|
950 |
+
is_small = n < max_exact
|
951 |
+
|
952 |
+
val_if_large = max_exact + (
|
953 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
954 |
+
).long()
|
955 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
956 |
+
|
957 |
+
ret += torch.where(is_small, n, val_if_large)
|
958 |
+
return ret
|
959 |
+
|
960 |
+
def forward(self, n, device):
|
961 |
+
q_pos = torch.arange(n, dtype = torch.long, device = device)
|
962 |
+
k_pos = torch.arange(n, dtype = torch.long, device = device)
|
963 |
+
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
964 |
+
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
965 |
+
values = self.relative_attention_bias(rp_bucket)
|
966 |
+
return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
|
models/clip.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
4 |
+
|
5 |
+
import transformers
|
6 |
+
transformers.logging.set_verbosity_error()
|
7 |
+
|
8 |
+
"""
|
9 |
+
Will encounter following warning:
|
10 |
+
- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
|
11 |
+
or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
|
12 |
+
- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
|
13 |
+
that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
|
14 |
+
|
15 |
+
https://github.com/CompVis/stable-diffusion/issues/97
|
16 |
+
according to this issue, this warning is safe.
|
17 |
+
|
18 |
+
This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
|
19 |
+
You can safely ignore the warning, it is not an error.
|
20 |
+
|
21 |
+
This clip usage is from U-ViT and same with Stable Diffusion.
|
22 |
+
"""
|
23 |
+
|
24 |
+
class AbstractEncoder(nn.Module):
|
25 |
+
def __init__(self):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
def encode(self, *args, **kwargs):
|
29 |
+
raise NotImplementedError
|
30 |
+
|
31 |
+
|
32 |
+
class FrozenCLIPEmbedder(AbstractEncoder):
|
33 |
+
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
34 |
+
# def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
|
35 |
+
def __init__(self, path, device="cuda", max_length=77):
|
36 |
+
super().__init__()
|
37 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
|
38 |
+
self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
|
39 |
+
self.device = device
|
40 |
+
self.max_length = max_length
|
41 |
+
self.freeze()
|
42 |
+
|
43 |
+
def freeze(self):
|
44 |
+
self.transformer = self.transformer.eval()
|
45 |
+
for param in self.parameters():
|
46 |
+
param.requires_grad = False
|
47 |
+
|
48 |
+
def forward(self, text):
|
49 |
+
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
50 |
+
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
51 |
+
tokens = batch_encoding["input_ids"].to(self.device)
|
52 |
+
outputs = self.transformer(input_ids=tokens)
|
53 |
+
|
54 |
+
z = outputs.last_hidden_state
|
55 |
+
return z
|
56 |
+
|
57 |
+
def encode(self, text):
|
58 |
+
return self(text)
|
59 |
+
|
60 |
+
|
61 |
+
class TextEmbedder(nn.Module):
|
62 |
+
"""
|
63 |
+
Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
|
64 |
+
"""
|
65 |
+
def __init__(self, path, dropout_prob=0.1):
|
66 |
+
super().__init__()
|
67 |
+
self.text_encodder = FrozenCLIPEmbedder(path=path)
|
68 |
+
self.dropout_prob = dropout_prob
|
69 |
+
|
70 |
+
def token_drop(self, text_prompts, force_drop_ids=None):
|
71 |
+
"""
|
72 |
+
Drops text to enable classifier-free guidance.
|
73 |
+
"""
|
74 |
+
if force_drop_ids is None:
|
75 |
+
drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
|
76 |
+
else:
|
77 |
+
# TODO
|
78 |
+
drop_ids = force_drop_ids == 1
|
79 |
+
labels = list(numpy.where(drop_ids, "", text_prompts))
|
80 |
+
# print(labels)
|
81 |
+
return labels
|
82 |
+
|
83 |
+
def forward(self, text_prompts, train, force_drop_ids=None):
|
84 |
+
use_dropout = self.dropout_prob > 0
|
85 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
86 |
+
text_prompts = self.token_drop(text_prompts, force_drop_ids)
|
87 |
+
embeddings = self.text_encodder(text_prompts)
|
88 |
+
return embeddings
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == '__main__':
|
92 |
+
|
93 |
+
r"""
|
94 |
+
Returns:
|
95 |
+
|
96 |
+
Examples from CLIPTextModel:
|
97 |
+
|
98 |
+
```python
|
99 |
+
>>> from transformers import AutoTokenizer, CLIPTextModel
|
100 |
+
|
101 |
+
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
102 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
103 |
+
|
104 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
105 |
+
|
106 |
+
>>> outputs = model(**inputs)
|
107 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
108 |
+
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
109 |
+
```"""
|
110 |
+
|
111 |
+
import torch
|
112 |
+
|
113 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
114 |
+
|
115 |
+
text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
|
116 |
+
dropout_prob=0.00001).to(device)
|
117 |
+
|
118 |
+
text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
|
119 |
+
# text_prompt = ('None', 'None', 'None')
|
120 |
+
output = text_encoder(text_prompts=text_prompt, train=False)
|
121 |
+
# print(output)
|
122 |
+
print(output.shape)
|
123 |
+
# print(output.shape)
|
models/resnet.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
|
13 |
+
class InflatedConv3d(nn.Conv2d):
|
14 |
+
def forward(self, x):
|
15 |
+
video_length = x.shape[2]
|
16 |
+
|
17 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
18 |
+
x = super().forward(x)
|
19 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
20 |
+
|
21 |
+
return x
|
22 |
+
|
23 |
+
|
24 |
+
class Upsample3D(nn.Module):
|
25 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
26 |
+
super().__init__()
|
27 |
+
self.channels = channels
|
28 |
+
self.out_channels = out_channels or channels
|
29 |
+
self.use_conv = use_conv
|
30 |
+
self.use_conv_transpose = use_conv_transpose
|
31 |
+
self.name = name
|
32 |
+
|
33 |
+
conv = None
|
34 |
+
if use_conv_transpose:
|
35 |
+
raise NotImplementedError
|
36 |
+
elif use_conv:
|
37 |
+
conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
|
38 |
+
|
39 |
+
if name == "conv":
|
40 |
+
self.conv = conv
|
41 |
+
else:
|
42 |
+
self.Conv2d_0 = conv
|
43 |
+
|
44 |
+
def forward(self, hidden_states, output_size=None):
|
45 |
+
assert hidden_states.shape[1] == self.channels
|
46 |
+
|
47 |
+
if self.use_conv_transpose:
|
48 |
+
raise NotImplementedError
|
49 |
+
|
50 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
51 |
+
dtype = hidden_states.dtype
|
52 |
+
if dtype == torch.bfloat16:
|
53 |
+
hidden_states = hidden_states.to(torch.float32)
|
54 |
+
|
55 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
56 |
+
if hidden_states.shape[0] >= 64:
|
57 |
+
hidden_states = hidden_states.contiguous()
|
58 |
+
|
59 |
+
# if `output_size` is passed we force the interpolation output
|
60 |
+
# size and do not make use of `scale_factor=2`
|
61 |
+
if output_size is None:
|
62 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
|
63 |
+
else:
|
64 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
65 |
+
|
66 |
+
# If the input is bfloat16, we cast back to bfloat16
|
67 |
+
if dtype == torch.bfloat16:
|
68 |
+
hidden_states = hidden_states.to(dtype)
|
69 |
+
|
70 |
+
if self.use_conv:
|
71 |
+
if self.name == "conv":
|
72 |
+
hidden_states = self.conv(hidden_states)
|
73 |
+
else:
|
74 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
75 |
+
|
76 |
+
return hidden_states
|
77 |
+
|
78 |
+
|
79 |
+
class Downsample3D(nn.Module):
|
80 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
81 |
+
super().__init__()
|
82 |
+
self.channels = channels
|
83 |
+
self.out_channels = out_channels or channels
|
84 |
+
self.use_conv = use_conv
|
85 |
+
self.padding = padding
|
86 |
+
stride = 2
|
87 |
+
self.name = name
|
88 |
+
|
89 |
+
if use_conv:
|
90 |
+
conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
91 |
+
else:
|
92 |
+
raise NotImplementedError
|
93 |
+
|
94 |
+
if name == "conv":
|
95 |
+
self.Conv2d_0 = conv
|
96 |
+
self.conv = conv
|
97 |
+
elif name == "Conv2d_0":
|
98 |
+
self.conv = conv
|
99 |
+
else:
|
100 |
+
self.conv = conv
|
101 |
+
|
102 |
+
def forward(self, hidden_states):
|
103 |
+
assert hidden_states.shape[1] == self.channels
|
104 |
+
if self.use_conv and self.padding == 0:
|
105 |
+
raise NotImplementedError
|
106 |
+
|
107 |
+
assert hidden_states.shape[1] == self.channels
|
108 |
+
hidden_states = self.conv(hidden_states)
|
109 |
+
|
110 |
+
return hidden_states
|
111 |
+
|
112 |
+
|
113 |
+
class ResnetBlock3D(nn.Module):
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
*,
|
117 |
+
in_channels,
|
118 |
+
out_channels=None,
|
119 |
+
conv_shortcut=False,
|
120 |
+
dropout=0.0,
|
121 |
+
temb_channels=512,
|
122 |
+
groups=32,
|
123 |
+
groups_out=None,
|
124 |
+
pre_norm=True,
|
125 |
+
eps=1e-6,
|
126 |
+
non_linearity="swish",
|
127 |
+
time_embedding_norm="default",
|
128 |
+
output_scale_factor=1.0,
|
129 |
+
use_in_shortcut=None,
|
130 |
+
):
|
131 |
+
super().__init__()
|
132 |
+
self.pre_norm = pre_norm
|
133 |
+
self.pre_norm = True
|
134 |
+
self.in_channels = in_channels
|
135 |
+
out_channels = in_channels if out_channels is None else out_channels
|
136 |
+
self.out_channels = out_channels
|
137 |
+
self.use_conv_shortcut = conv_shortcut
|
138 |
+
self.time_embedding_norm = time_embedding_norm
|
139 |
+
self.output_scale_factor = output_scale_factor
|
140 |
+
|
141 |
+
if groups_out is None:
|
142 |
+
groups_out = groups
|
143 |
+
|
144 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
145 |
+
|
146 |
+
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
147 |
+
|
148 |
+
if temb_channels is not None:
|
149 |
+
if self.time_embedding_norm == "default":
|
150 |
+
time_emb_proj_out_channels = out_channels
|
151 |
+
elif self.time_embedding_norm == "scale_shift":
|
152 |
+
time_emb_proj_out_channels = out_channels * 2
|
153 |
+
else:
|
154 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
155 |
+
|
156 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
157 |
+
else:
|
158 |
+
self.time_emb_proj = None
|
159 |
+
|
160 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
161 |
+
self.dropout = torch.nn.Dropout(dropout)
|
162 |
+
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
163 |
+
|
164 |
+
if non_linearity == "swish":
|
165 |
+
self.nonlinearity = lambda x: F.silu(x)
|
166 |
+
elif non_linearity == "mish":
|
167 |
+
self.nonlinearity = Mish()
|
168 |
+
elif non_linearity == "silu":
|
169 |
+
self.nonlinearity = nn.SiLU()
|
170 |
+
|
171 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
172 |
+
|
173 |
+
self.conv_shortcut = None
|
174 |
+
if self.use_in_shortcut:
|
175 |
+
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
176 |
+
|
177 |
+
def forward(self, input_tensor, temb):
|
178 |
+
hidden_states = input_tensor
|
179 |
+
|
180 |
+
hidden_states = self.norm1(hidden_states)
|
181 |
+
hidden_states = self.nonlinearity(hidden_states)
|
182 |
+
|
183 |
+
hidden_states = self.conv1(hidden_states)
|
184 |
+
|
185 |
+
if temb is not None:
|
186 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
187 |
+
|
188 |
+
if temb is not None and self.time_embedding_norm == "default":
|
189 |
+
hidden_states = hidden_states + temb
|
190 |
+
|
191 |
+
hidden_states = self.norm2(hidden_states)
|
192 |
+
|
193 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
194 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
195 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
196 |
+
|
197 |
+
hidden_states = self.nonlinearity(hidden_states)
|
198 |
+
|
199 |
+
hidden_states = self.dropout(hidden_states)
|
200 |
+
hidden_states = self.conv2(hidden_states)
|
201 |
+
|
202 |
+
if self.conv_shortcut is not None:
|
203 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
204 |
+
|
205 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
206 |
+
|
207 |
+
return output_tensor
|
208 |
+
|
209 |
+
|
210 |
+
class Mish(torch.nn.Module):
|
211 |
+
def forward(self, hidden_states):
|
212 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
models/unet.py
ADDED
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
9 |
+
|
10 |
+
import math
|
11 |
+
import json
|
12 |
+
import torch
|
13 |
+
import einops
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.utils.checkpoint
|
16 |
+
|
17 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
18 |
+
from diffusers.utils import BaseOutput, logging
|
19 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
20 |
+
from einops import rearrange
|
21 |
+
|
22 |
+
try:
|
23 |
+
from diffusers.models.modeling_utils import ModelMixin
|
24 |
+
except:
|
25 |
+
from diffusers.modeling_utils import ModelMixin # 0.11.1
|
26 |
+
|
27 |
+
try:
|
28 |
+
from .unet_blocks import (
|
29 |
+
CrossAttnDownBlock3D,
|
30 |
+
CrossAttnUpBlock3D,
|
31 |
+
DownBlock3D,
|
32 |
+
UNetMidBlock3DCrossAttn,
|
33 |
+
UpBlock3D,
|
34 |
+
get_down_block,
|
35 |
+
get_up_block,
|
36 |
+
)
|
37 |
+
from .resnet import InflatedConv3d
|
38 |
+
except:
|
39 |
+
from unet_blocks import (
|
40 |
+
CrossAttnDownBlock3D,
|
41 |
+
CrossAttnUpBlock3D,
|
42 |
+
DownBlock3D,
|
43 |
+
UNetMidBlock3DCrossAttn,
|
44 |
+
UpBlock3D,
|
45 |
+
get_down_block,
|
46 |
+
get_up_block,
|
47 |
+
)
|
48 |
+
from resnet import InflatedConv3d
|
49 |
+
|
50 |
+
from rotary_embedding_torch import RotaryEmbedding
|
51 |
+
|
52 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
53 |
+
|
54 |
+
class RelativePositionBias(nn.Module):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
heads=8,
|
58 |
+
num_buckets=32,
|
59 |
+
max_distance=128,
|
60 |
+
):
|
61 |
+
super().__init__()
|
62 |
+
self.num_buckets = num_buckets
|
63 |
+
self.max_distance = max_distance
|
64 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
|
68 |
+
ret = 0
|
69 |
+
n = -relative_position
|
70 |
+
|
71 |
+
num_buckets //= 2
|
72 |
+
ret += (n < 0).long() * num_buckets
|
73 |
+
n = torch.abs(n)
|
74 |
+
|
75 |
+
max_exact = num_buckets // 2
|
76 |
+
is_small = n < max_exact
|
77 |
+
|
78 |
+
val_if_large = max_exact + (
|
79 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
80 |
+
).long()
|
81 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
82 |
+
|
83 |
+
ret += torch.where(is_small, n, val_if_large)
|
84 |
+
return ret
|
85 |
+
|
86 |
+
def forward(self, n, device):
|
87 |
+
q_pos = torch.arange(n, dtype = torch.long, device = device)
|
88 |
+
k_pos = torch.arange(n, dtype = torch.long, device = device)
|
89 |
+
rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
|
90 |
+
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
91 |
+
values = self.relative_attention_bias(rp_bucket)
|
92 |
+
return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
|
93 |
+
|
94 |
+
@dataclass
|
95 |
+
class UNet3DConditionOutput(BaseOutput):
|
96 |
+
sample: torch.FloatTensor
|
97 |
+
|
98 |
+
|
99 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
100 |
+
_supports_gradient_checkpointing = True
|
101 |
+
|
102 |
+
@register_to_config
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
sample_size: Optional[int] = None, # 64
|
106 |
+
in_channels: int = 4,
|
107 |
+
out_channels: int = 4,
|
108 |
+
center_input_sample: bool = False,
|
109 |
+
flip_sin_to_cos: bool = True,
|
110 |
+
freq_shift: int = 0,
|
111 |
+
down_block_types: Tuple[str] = (
|
112 |
+
"CrossAttnDownBlock3D",
|
113 |
+
"CrossAttnDownBlock3D",
|
114 |
+
"CrossAttnDownBlock3D",
|
115 |
+
"DownBlock3D",
|
116 |
+
),
|
117 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
118 |
+
up_block_types: Tuple[str] = (
|
119 |
+
"UpBlock3D",
|
120 |
+
"CrossAttnUpBlock3D",
|
121 |
+
"CrossAttnUpBlock3D",
|
122 |
+
"CrossAttnUpBlock3D"
|
123 |
+
),
|
124 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
125 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
126 |
+
layers_per_block: int = 2,
|
127 |
+
downsample_padding: int = 1,
|
128 |
+
mid_block_scale_factor: float = 1,
|
129 |
+
act_fn: str = "silu",
|
130 |
+
norm_num_groups: int = 32,
|
131 |
+
norm_eps: float = 1e-5,
|
132 |
+
cross_attention_dim: int = 1280,
|
133 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
134 |
+
dual_cross_attention: bool = False,
|
135 |
+
use_linear_projection: bool = False,
|
136 |
+
class_embed_type: Optional[str] = None,
|
137 |
+
num_class_embeds: Optional[int] = None,
|
138 |
+
upcast_attention: bool = False,
|
139 |
+
resnet_time_scale_shift: str = "default",
|
140 |
+
use_first_frame: bool = False,
|
141 |
+
use_relative_position: bool = False,
|
142 |
+
):
|
143 |
+
super().__init__()
|
144 |
+
|
145 |
+
# print(use_first_frame)
|
146 |
+
|
147 |
+
self.sample_size = sample_size
|
148 |
+
time_embed_dim = block_out_channels[0] * 4
|
149 |
+
|
150 |
+
# input
|
151 |
+
self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
152 |
+
|
153 |
+
# time
|
154 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
155 |
+
timestep_input_dim = block_out_channels[0]
|
156 |
+
|
157 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
158 |
+
|
159 |
+
# class embedding
|
160 |
+
if class_embed_type is None and num_class_embeds is not None:
|
161 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
162 |
+
elif class_embed_type == "timestep":
|
163 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
164 |
+
elif class_embed_type == "identity":
|
165 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
166 |
+
else:
|
167 |
+
self.class_embedding = None
|
168 |
+
|
169 |
+
self.down_blocks = nn.ModuleList([])
|
170 |
+
self.mid_block = None
|
171 |
+
self.up_blocks = nn.ModuleList([])
|
172 |
+
|
173 |
+
if isinstance(only_cross_attention, bool):
|
174 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
175 |
+
|
176 |
+
if isinstance(attention_head_dim, int):
|
177 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
178 |
+
|
179 |
+
rotary_emb = RotaryEmbedding(32)
|
180 |
+
|
181 |
+
# down
|
182 |
+
output_channel = block_out_channels[0]
|
183 |
+
for i, down_block_type in enumerate(down_block_types):
|
184 |
+
input_channel = output_channel
|
185 |
+
output_channel = block_out_channels[i]
|
186 |
+
is_final_block = i == len(block_out_channels) - 1
|
187 |
+
|
188 |
+
down_block = get_down_block(
|
189 |
+
down_block_type,
|
190 |
+
num_layers=layers_per_block,
|
191 |
+
in_channels=input_channel,
|
192 |
+
out_channels=output_channel,
|
193 |
+
temb_channels=time_embed_dim,
|
194 |
+
add_downsample=not is_final_block,
|
195 |
+
resnet_eps=norm_eps,
|
196 |
+
resnet_act_fn=act_fn,
|
197 |
+
resnet_groups=norm_num_groups,
|
198 |
+
cross_attention_dim=cross_attention_dim,
|
199 |
+
attn_num_head_channels=attention_head_dim[i],
|
200 |
+
downsample_padding=downsample_padding,
|
201 |
+
dual_cross_attention=dual_cross_attention,
|
202 |
+
use_linear_projection=use_linear_projection,
|
203 |
+
only_cross_attention=only_cross_attention[i],
|
204 |
+
upcast_attention=upcast_attention,
|
205 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
206 |
+
use_first_frame=use_first_frame,
|
207 |
+
use_relative_position=use_relative_position,
|
208 |
+
rotary_emb=rotary_emb,
|
209 |
+
)
|
210 |
+
self.down_blocks.append(down_block)
|
211 |
+
|
212 |
+
# mid
|
213 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
214 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
215 |
+
in_channels=block_out_channels[-1],
|
216 |
+
temb_channels=time_embed_dim,
|
217 |
+
resnet_eps=norm_eps,
|
218 |
+
resnet_act_fn=act_fn,
|
219 |
+
output_scale_factor=mid_block_scale_factor,
|
220 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
221 |
+
cross_attention_dim=cross_attention_dim,
|
222 |
+
attn_num_head_channels=attention_head_dim[-1],
|
223 |
+
resnet_groups=norm_num_groups,
|
224 |
+
dual_cross_attention=dual_cross_attention,
|
225 |
+
use_linear_projection=use_linear_projection,
|
226 |
+
upcast_attention=upcast_attention,
|
227 |
+
use_first_frame=use_first_frame,
|
228 |
+
use_relative_position=use_relative_position,
|
229 |
+
rotary_emb=rotary_emb,
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
233 |
+
|
234 |
+
# count how many layers upsample the videos
|
235 |
+
self.num_upsamplers = 0
|
236 |
+
|
237 |
+
# up
|
238 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
239 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
240 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
241 |
+
output_channel = reversed_block_out_channels[0]
|
242 |
+
for i, up_block_type in enumerate(up_block_types):
|
243 |
+
is_final_block = i == len(block_out_channels) - 1
|
244 |
+
|
245 |
+
prev_output_channel = output_channel
|
246 |
+
output_channel = reversed_block_out_channels[i]
|
247 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
248 |
+
|
249 |
+
# add upsample block for all BUT final layer
|
250 |
+
if not is_final_block:
|
251 |
+
add_upsample = True
|
252 |
+
self.num_upsamplers += 1
|
253 |
+
else:
|
254 |
+
add_upsample = False
|
255 |
+
|
256 |
+
up_block = get_up_block(
|
257 |
+
up_block_type,
|
258 |
+
num_layers=layers_per_block + 1,
|
259 |
+
in_channels=input_channel,
|
260 |
+
out_channels=output_channel,
|
261 |
+
prev_output_channel=prev_output_channel,
|
262 |
+
temb_channels=time_embed_dim,
|
263 |
+
add_upsample=add_upsample,
|
264 |
+
resnet_eps=norm_eps,
|
265 |
+
resnet_act_fn=act_fn,
|
266 |
+
resnet_groups=norm_num_groups,
|
267 |
+
cross_attention_dim=cross_attention_dim,
|
268 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
269 |
+
dual_cross_attention=dual_cross_attention,
|
270 |
+
use_linear_projection=use_linear_projection,
|
271 |
+
only_cross_attention=only_cross_attention[i],
|
272 |
+
upcast_attention=upcast_attention,
|
273 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
274 |
+
use_first_frame=use_first_frame,
|
275 |
+
use_relative_position=use_relative_position,
|
276 |
+
rotary_emb=rotary_emb,
|
277 |
+
)
|
278 |
+
self.up_blocks.append(up_block)
|
279 |
+
prev_output_channel = output_channel
|
280 |
+
|
281 |
+
# out
|
282 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
283 |
+
self.conv_act = nn.SiLU()
|
284 |
+
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
285 |
+
|
286 |
+
# relative time positional embeddings
|
287 |
+
self.use_relative_position = use_relative_position
|
288 |
+
if self.use_relative_position:
|
289 |
+
self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet
|
290 |
+
|
291 |
+
def set_attention_slice(self, slice_size):
|
292 |
+
r"""
|
293 |
+
Enable sliced attention computation.
|
294 |
+
|
295 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
296 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
300 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
301 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
302 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
303 |
+
must be a multiple of `slice_size`.
|
304 |
+
"""
|
305 |
+
sliceable_head_dims = []
|
306 |
+
|
307 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
308 |
+
if hasattr(module, "set_attention_slice"):
|
309 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
310 |
+
|
311 |
+
for child in module.children():
|
312 |
+
fn_recursive_retrieve_slicable_dims(child)
|
313 |
+
|
314 |
+
# retrieve number of attention layers
|
315 |
+
for module in self.children():
|
316 |
+
fn_recursive_retrieve_slicable_dims(module)
|
317 |
+
|
318 |
+
num_slicable_layers = len(sliceable_head_dims)
|
319 |
+
|
320 |
+
if slice_size == "auto":
|
321 |
+
# half the attention head size is usually a good trade-off between
|
322 |
+
# speed and memory
|
323 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
324 |
+
elif slice_size == "max":
|
325 |
+
# make smallest slice possible
|
326 |
+
slice_size = num_slicable_layers * [1]
|
327 |
+
|
328 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
329 |
+
|
330 |
+
if len(slice_size) != len(sliceable_head_dims):
|
331 |
+
raise ValueError(
|
332 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
333 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
334 |
+
)
|
335 |
+
|
336 |
+
for i in range(len(slice_size)):
|
337 |
+
size = slice_size[i]
|
338 |
+
dim = sliceable_head_dims[i]
|
339 |
+
if size is not None and size > dim:
|
340 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
341 |
+
|
342 |
+
# Recursively walk through all the children.
|
343 |
+
# Any children which exposes the set_attention_slice method
|
344 |
+
# gets the message
|
345 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
346 |
+
if hasattr(module, "set_attention_slice"):
|
347 |
+
module.set_attention_slice(slice_size.pop())
|
348 |
+
|
349 |
+
for child in module.children():
|
350 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
351 |
+
|
352 |
+
reversed_slice_size = list(reversed(slice_size))
|
353 |
+
for module in self.children():
|
354 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
355 |
+
|
356 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
357 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
358 |
+
module.gradient_checkpointing = value
|
359 |
+
|
360 |
+
def forward(
|
361 |
+
self,
|
362 |
+
sample: torch.FloatTensor,
|
363 |
+
timestep: Union[torch.Tensor, float, int],
|
364 |
+
encoder_hidden_states: torch.Tensor = None,
|
365 |
+
class_labels: Optional[torch.Tensor] = None,
|
366 |
+
attention_mask: Optional[torch.Tensor] = None,
|
367 |
+
use_image_num: int = 0,
|
368 |
+
return_dict: bool = True,
|
369 |
+
ip_hidden_states = None,
|
370 |
+
encoder_temporal_hidden_states = None
|
371 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
372 |
+
r"""
|
373 |
+
Args:
|
374 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
375 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
376 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
377 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
378 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
379 |
+
|
380 |
+
Returns:
|
381 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
382 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
383 |
+
returning a tuple, the first element is the sample tensor.
|
384 |
+
"""
|
385 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
386 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
387 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
388 |
+
# on the fly if necessary.
|
389 |
+
if ip_hidden_states is not None:
|
390 |
+
b = ip_hidden_states.shape[0]
|
391 |
+
ip_hidden_states = rearrange(ip_hidden_states, 'b n c -> (b n) c')
|
392 |
+
ip_hidden_states = self.image_proj_model(ip_hidden_states)
|
393 |
+
ip_hidden_states = rearrange(ip_hidden_states, '(b n) m c -> b n m c', b=b)
|
394 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
395 |
+
|
396 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
397 |
+
forward_upsample_size = False
|
398 |
+
upsample_size = None
|
399 |
+
|
400 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
401 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
402 |
+
forward_upsample_size = True
|
403 |
+
|
404 |
+
# prepare attention_mask
|
405 |
+
if attention_mask is not None:
|
406 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
407 |
+
attention_mask = attention_mask.unsqueeze(1)
|
408 |
+
|
409 |
+
# center input if necessary
|
410 |
+
if self.config.center_input_sample:
|
411 |
+
sample = 2 * sample - 1.0
|
412 |
+
|
413 |
+
# time
|
414 |
+
timesteps = timestep
|
415 |
+
if not torch.is_tensor(timesteps):
|
416 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
417 |
+
is_mps = sample.device.type == "mps"
|
418 |
+
if isinstance(timestep, float):
|
419 |
+
dtype = torch.float32 if is_mps else torch.float64
|
420 |
+
else:
|
421 |
+
dtype = torch.int32 if is_mps else torch.int64
|
422 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
423 |
+
elif len(timesteps.shape) == 0:
|
424 |
+
timesteps = timesteps[None].to(sample.device)
|
425 |
+
|
426 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
427 |
+
timesteps = timesteps.expand(sample.shape[0])
|
428 |
+
|
429 |
+
t_emb = self.time_proj(timesteps)
|
430 |
+
|
431 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
432 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
433 |
+
# there might be better ways to encapsulate this.
|
434 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
435 |
+
emb = self.time_embedding(t_emb)
|
436 |
+
|
437 |
+
if self.class_embedding is not None:
|
438 |
+
if class_labels is None:
|
439 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
440 |
+
|
441 |
+
if self.config.class_embed_type == "timestep":
|
442 |
+
class_labels = self.time_proj(class_labels)
|
443 |
+
|
444 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
445 |
+
# print(emb.shape) # torch.Size([3, 1280])
|
446 |
+
# print(class_emb.shape) # torch.Size([3, 1280])
|
447 |
+
emb = emb + class_emb
|
448 |
+
|
449 |
+
if self.use_relative_position:
|
450 |
+
frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device)
|
451 |
+
else:
|
452 |
+
frame_rel_pos_bias = None
|
453 |
+
|
454 |
+
# pre-process
|
455 |
+
sample = self.conv_in(sample)
|
456 |
+
|
457 |
+
# down
|
458 |
+
down_block_res_samples = (sample,)
|
459 |
+
for downsample_block in self.down_blocks:
|
460 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
461 |
+
sample, res_samples = downsample_block(
|
462 |
+
hidden_states=sample,
|
463 |
+
temb=emb,
|
464 |
+
encoder_hidden_states=encoder_hidden_states,
|
465 |
+
attention_mask=attention_mask,
|
466 |
+
use_image_num=use_image_num,
|
467 |
+
ip_hidden_states=ip_hidden_states,
|
468 |
+
encoder_temporal_hidden_states=encoder_temporal_hidden_states
|
469 |
+
)
|
470 |
+
else:
|
471 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
472 |
+
|
473 |
+
down_block_res_samples += res_samples
|
474 |
+
|
475 |
+
# mid
|
476 |
+
sample = self.mid_block(
|
477 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states
|
478 |
+
)
|
479 |
+
|
480 |
+
# up
|
481 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
482 |
+
is_final_block = i == len(self.up_blocks) - 1
|
483 |
+
|
484 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
485 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
486 |
+
|
487 |
+
# if we have not reached the final block and need to forward the
|
488 |
+
# upsample size, we do it here
|
489 |
+
if not is_final_block and forward_upsample_size:
|
490 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
491 |
+
|
492 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
493 |
+
sample = upsample_block(
|
494 |
+
hidden_states=sample,
|
495 |
+
temb=emb,
|
496 |
+
res_hidden_states_tuple=res_samples,
|
497 |
+
encoder_hidden_states=encoder_hidden_states,
|
498 |
+
upsample_size=upsample_size,
|
499 |
+
attention_mask=attention_mask,
|
500 |
+
use_image_num=use_image_num,
|
501 |
+
ip_hidden_states=ip_hidden_states,
|
502 |
+
encoder_temporal_hidden_states=encoder_temporal_hidden_states
|
503 |
+
)
|
504 |
+
else:
|
505 |
+
sample = upsample_block(
|
506 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
507 |
+
)
|
508 |
+
# post-process
|
509 |
+
sample = self.conv_norm_out(sample)
|
510 |
+
sample = self.conv_act(sample)
|
511 |
+
sample = self.conv_out(sample)
|
512 |
+
# print(sample.shape)
|
513 |
+
|
514 |
+
if not return_dict:
|
515 |
+
return (sample,)
|
516 |
+
sample = UNet3DConditionOutput(sample=sample)
|
517 |
+
return sample
|
518 |
+
|
519 |
+
def forward_with_cfg(self,
|
520 |
+
x,
|
521 |
+
t,
|
522 |
+
encoder_hidden_states = None,
|
523 |
+
class_labels: Optional[torch.Tensor] = None,
|
524 |
+
cfg_scale=4.0,
|
525 |
+
use_fp16=False,
|
526 |
+
ip_hidden_states = None):
|
527 |
+
"""
|
528 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
529 |
+
"""
|
530 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
531 |
+
half = x[: len(x) // 2]
|
532 |
+
combined = torch.cat([half, half], dim=0)
|
533 |
+
if use_fp16:
|
534 |
+
combined = combined.to(dtype=torch.float16)
|
535 |
+
model_out = self.forward(combined, t, encoder_hidden_states, class_labels, ip_hidden_states=ip_hidden_states).sample
|
536 |
+
# For exact reproducibility reasons, we apply classifier-free guidance on only
|
537 |
+
# three channels by default. The standard approach to cfg applies it to all channels.
|
538 |
+
# This can be done by uncommenting the following line and commenting-out the line following that.
|
539 |
+
eps, rest = model_out[:, :4], model_out[:, 4:]
|
540 |
+
# eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w
|
541 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
542 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
543 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
544 |
+
return torch.cat([eps, rest], dim=1)
|
545 |
+
|
546 |
+
@classmethod
|
547 |
+
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False):
|
548 |
+
if subfolder is not None:
|
549 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
550 |
+
|
551 |
+
|
552 |
+
# the content of the config file
|
553 |
+
# {
|
554 |
+
# "_class_name": "UNet2DConditionModel",
|
555 |
+
# "_diffusers_version": "0.2.2",
|
556 |
+
# "act_fn": "silu",
|
557 |
+
# "attention_head_dim": 8,
|
558 |
+
# "block_out_channels": [
|
559 |
+
# 320,
|
560 |
+
# 640,
|
561 |
+
# 1280,
|
562 |
+
# 1280
|
563 |
+
# ],
|
564 |
+
# "center_input_sample": false,
|
565 |
+
# "cross_attention_dim": 768,
|
566 |
+
# "down_block_types": [
|
567 |
+
# "CrossAttnDownBlock2D",
|
568 |
+
# "CrossAttnDownBlock2D",
|
569 |
+
# "CrossAttnDownBlock2D",
|
570 |
+
# "DownBlock2D"
|
571 |
+
# ],
|
572 |
+
# "downsample_padding": 1,
|
573 |
+
# "flip_sin_to_cos": true,
|
574 |
+
# "freq_shift": 0,
|
575 |
+
# "in_channels": 4,
|
576 |
+
# "layers_per_block": 2,
|
577 |
+
# "mid_block_scale_factor": 1,
|
578 |
+
# "norm_eps": 1e-05,
|
579 |
+
# "norm_num_groups": 32,
|
580 |
+
# "out_channels": 4,
|
581 |
+
# "sample_size": 64,
|
582 |
+
# "up_block_types": [
|
583 |
+
# "UpBlock2D",
|
584 |
+
# "CrossAttnUpBlock2D",
|
585 |
+
# "CrossAttnUpBlock2D",
|
586 |
+
# "CrossAttnUpBlock2D"
|
587 |
+
# ]
|
588 |
+
# }
|
589 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
590 |
+
if not os.path.isfile(config_file):
|
591 |
+
raise RuntimeError(f"{config_file} does not exist")
|
592 |
+
with open(config_file, "r") as f:
|
593 |
+
config = json.load(f)
|
594 |
+
config["_class_name"] = cls.__name__
|
595 |
+
config["down_block_types"] = [
|
596 |
+
"CrossAttnDownBlock3D",
|
597 |
+
"CrossAttnDownBlock3D",
|
598 |
+
"CrossAttnDownBlock3D",
|
599 |
+
"DownBlock3D"
|
600 |
+
]
|
601 |
+
config["up_block_types"] = [
|
602 |
+
"UpBlock3D",
|
603 |
+
"CrossAttnUpBlock3D",
|
604 |
+
"CrossAttnUpBlock3D",
|
605 |
+
"CrossAttnUpBlock3D"
|
606 |
+
]
|
607 |
+
|
608 |
+
# config["use_first_frame"] = True
|
609 |
+
|
610 |
+
config["use_first_frame"] = False
|
611 |
+
if use_concat:
|
612 |
+
config["in_channels"] = 9
|
613 |
+
# config["use_relative_position"] = True
|
614 |
+
|
615 |
+
# # tmp
|
616 |
+
# config["class_embed_type"] = "timestep"
|
617 |
+
# config["num_class_embeds"] = 100
|
618 |
+
|
619 |
+
from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
|
620 |
+
|
621 |
+
# {'_class_name': 'UNet3DConditionModel',
|
622 |
+
# '_diffusers_version': '0.2.2',
|
623 |
+
# 'act_fn': 'silu',
|
624 |
+
# 'attention_head_dim': 8,
|
625 |
+
# 'block_out_channels': [320, 640, 1280, 1280],
|
626 |
+
# 'center_input_sample': False,
|
627 |
+
# 'cross_attention_dim': 768,
|
628 |
+
# 'down_block_types':
|
629 |
+
# ['CrossAttnDownBlock3D',
|
630 |
+
# 'CrossAttnDownBlock3D',
|
631 |
+
# 'CrossAttnDownBlock3D',
|
632 |
+
# 'DownBlock3D'],
|
633 |
+
# 'downsample_padding': 1,
|
634 |
+
# 'flip_sin_to_cos': True,
|
635 |
+
# 'freq_shift': 0,
|
636 |
+
# 'in_channels': 4,
|
637 |
+
# 'layers_per_block': 2,
|
638 |
+
# 'mid_block_scale_factor': 1,
|
639 |
+
# 'norm_eps': 1e-05,
|
640 |
+
# 'norm_num_groups': 32,
|
641 |
+
# 'out_channels': 4,
|
642 |
+
# 'sample_size': 64,
|
643 |
+
# 'up_block_types':
|
644 |
+
# ['UpBlock3D',
|
645 |
+
# 'CrossAttnUpBlock3D',
|
646 |
+
# 'CrossAttnUpBlock3D',
|
647 |
+
# 'CrossAttnUpBlock3D']}
|
648 |
+
|
649 |
+
model = cls.from_config(config)
|
650 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
651 |
+
if not os.path.isfile(model_file):
|
652 |
+
raise RuntimeError(f"{model_file} does not exist")
|
653 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
654 |
+
|
655 |
+
if use_concat:
|
656 |
+
new_state_dict = {}
|
657 |
+
conv_in_weight = state_dict["conv_in.weight"]
|
658 |
+
new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
|
659 |
+
|
660 |
+
for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
|
661 |
+
new_conv_weight[:, j] = conv_in_weight[:, i]
|
662 |
+
new_state_dict["conv_in.weight"] = new_conv_weight
|
663 |
+
new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
|
664 |
+
for k, v in model.state_dict().items():
|
665 |
+
# print(k)
|
666 |
+
if '_temp.' in k:
|
667 |
+
new_state_dict.update({k: v})
|
668 |
+
if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
669 |
+
k = k.replace('attn_fcross', 'attn1')
|
670 |
+
state_dict.update({k: state_dict[k]})
|
671 |
+
if 'norm_fcross' in k:
|
672 |
+
k = k.replace('norm_fcross', 'norm1')
|
673 |
+
state_dict.update({k: state_dict[k]})
|
674 |
+
|
675 |
+
if 'conv_in' in k:
|
676 |
+
continue
|
677 |
+
else:
|
678 |
+
new_state_dict[k] = v
|
679 |
+
# # tmp
|
680 |
+
# if 'class_embedding' in k:
|
681 |
+
# state_dict.update({k: v})
|
682 |
+
# breakpoint()
|
683 |
+
model.load_state_dict(new_state_dict)
|
684 |
+
else:
|
685 |
+
for k, v in model.state_dict().items():
|
686 |
+
# print(k)
|
687 |
+
if '_temp' in k:
|
688 |
+
state_dict.update({k: v})
|
689 |
+
if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
690 |
+
k = k.replace('attn_fcross', 'attn1')
|
691 |
+
state_dict.update({k: state_dict[k]})
|
692 |
+
if 'norm_fcross' in k:
|
693 |
+
k = k.replace('norm_fcross', 'norm1')
|
694 |
+
state_dict.update({k: state_dict[k]})
|
695 |
+
|
696 |
+
model.load_state_dict(state_dict)
|
697 |
+
|
698 |
+
return model
|
699 |
+
|
models/unet_blocks.py
ADDED
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
try:
|
10 |
+
from .attention import Transformer3DModel
|
11 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
12 |
+
except:
|
13 |
+
from attention import Transformer3DModel
|
14 |
+
from resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
15 |
+
|
16 |
+
|
17 |
+
def get_down_block(
|
18 |
+
down_block_type,
|
19 |
+
num_layers,
|
20 |
+
in_channels,
|
21 |
+
out_channels,
|
22 |
+
temb_channels,
|
23 |
+
add_downsample,
|
24 |
+
resnet_eps,
|
25 |
+
resnet_act_fn,
|
26 |
+
attn_num_head_channels,
|
27 |
+
resnet_groups=None,
|
28 |
+
cross_attention_dim=None,
|
29 |
+
downsample_padding=None,
|
30 |
+
dual_cross_attention=False,
|
31 |
+
use_linear_projection=False,
|
32 |
+
only_cross_attention=False,
|
33 |
+
upcast_attention=False,
|
34 |
+
resnet_time_scale_shift="default",
|
35 |
+
use_first_frame=False,
|
36 |
+
use_relative_position=False,
|
37 |
+
rotary_emb=False,
|
38 |
+
):
|
39 |
+
# print(down_block_type)
|
40 |
+
# print(use_first_frame)
|
41 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
42 |
+
if down_block_type == "DownBlock3D":
|
43 |
+
return DownBlock3D(
|
44 |
+
num_layers=num_layers,
|
45 |
+
in_channels=in_channels,
|
46 |
+
out_channels=out_channels,
|
47 |
+
temb_channels=temb_channels,
|
48 |
+
add_downsample=add_downsample,
|
49 |
+
resnet_eps=resnet_eps,
|
50 |
+
resnet_act_fn=resnet_act_fn,
|
51 |
+
resnet_groups=resnet_groups,
|
52 |
+
downsample_padding=downsample_padding,
|
53 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
54 |
+
)
|
55 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
56 |
+
if cross_attention_dim is None:
|
57 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
58 |
+
return CrossAttnDownBlock3D(
|
59 |
+
num_layers=num_layers,
|
60 |
+
in_channels=in_channels,
|
61 |
+
out_channels=out_channels,
|
62 |
+
temb_channels=temb_channels,
|
63 |
+
add_downsample=add_downsample,
|
64 |
+
resnet_eps=resnet_eps,
|
65 |
+
resnet_act_fn=resnet_act_fn,
|
66 |
+
resnet_groups=resnet_groups,
|
67 |
+
downsample_padding=downsample_padding,
|
68 |
+
cross_attention_dim=cross_attention_dim,
|
69 |
+
attn_num_head_channels=attn_num_head_channels,
|
70 |
+
dual_cross_attention=dual_cross_attention,
|
71 |
+
use_linear_projection=use_linear_projection,
|
72 |
+
only_cross_attention=only_cross_attention,
|
73 |
+
upcast_attention=upcast_attention,
|
74 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
75 |
+
use_first_frame=use_first_frame,
|
76 |
+
use_relative_position=use_relative_position,
|
77 |
+
rotary_emb=rotary_emb,
|
78 |
+
)
|
79 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
80 |
+
|
81 |
+
|
82 |
+
def get_up_block(
|
83 |
+
up_block_type,
|
84 |
+
num_layers,
|
85 |
+
in_channels,
|
86 |
+
out_channels,
|
87 |
+
prev_output_channel,
|
88 |
+
temb_channels,
|
89 |
+
add_upsample,
|
90 |
+
resnet_eps,
|
91 |
+
resnet_act_fn,
|
92 |
+
attn_num_head_channels,
|
93 |
+
resnet_groups=None,
|
94 |
+
cross_attention_dim=None,
|
95 |
+
dual_cross_attention=False,
|
96 |
+
use_linear_projection=False,
|
97 |
+
only_cross_attention=False,
|
98 |
+
upcast_attention=False,
|
99 |
+
resnet_time_scale_shift="default",
|
100 |
+
use_first_frame=False,
|
101 |
+
use_relative_position=False,
|
102 |
+
rotary_emb=False,
|
103 |
+
):
|
104 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
105 |
+
if up_block_type == "UpBlock3D":
|
106 |
+
return UpBlock3D(
|
107 |
+
num_layers=num_layers,
|
108 |
+
in_channels=in_channels,
|
109 |
+
out_channels=out_channels,
|
110 |
+
prev_output_channel=prev_output_channel,
|
111 |
+
temb_channels=temb_channels,
|
112 |
+
add_upsample=add_upsample,
|
113 |
+
resnet_eps=resnet_eps,
|
114 |
+
resnet_act_fn=resnet_act_fn,
|
115 |
+
resnet_groups=resnet_groups,
|
116 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
117 |
+
)
|
118 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
119 |
+
if cross_attention_dim is None:
|
120 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
121 |
+
return CrossAttnUpBlock3D(
|
122 |
+
num_layers=num_layers,
|
123 |
+
in_channels=in_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
prev_output_channel=prev_output_channel,
|
126 |
+
temb_channels=temb_channels,
|
127 |
+
add_upsample=add_upsample,
|
128 |
+
resnet_eps=resnet_eps,
|
129 |
+
resnet_act_fn=resnet_act_fn,
|
130 |
+
resnet_groups=resnet_groups,
|
131 |
+
cross_attention_dim=cross_attention_dim,
|
132 |
+
attn_num_head_channels=attn_num_head_channels,
|
133 |
+
dual_cross_attention=dual_cross_attention,
|
134 |
+
use_linear_projection=use_linear_projection,
|
135 |
+
only_cross_attention=only_cross_attention,
|
136 |
+
upcast_attention=upcast_attention,
|
137 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
138 |
+
use_first_frame=use_first_frame,
|
139 |
+
use_relative_position=use_relative_position,
|
140 |
+
rotary_emb=rotary_emb,
|
141 |
+
)
|
142 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
143 |
+
|
144 |
+
|
145 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
in_channels: int,
|
149 |
+
temb_channels: int,
|
150 |
+
dropout: float = 0.0,
|
151 |
+
num_layers: int = 1,
|
152 |
+
resnet_eps: float = 1e-6,
|
153 |
+
resnet_time_scale_shift: str = "default",
|
154 |
+
resnet_act_fn: str = "swish",
|
155 |
+
resnet_groups: int = 32,
|
156 |
+
resnet_pre_norm: bool = True,
|
157 |
+
attn_num_head_channels=1,
|
158 |
+
output_scale_factor=1.0,
|
159 |
+
cross_attention_dim=1280,
|
160 |
+
dual_cross_attention=False,
|
161 |
+
use_linear_projection=False,
|
162 |
+
upcast_attention=False,
|
163 |
+
use_first_frame=False,
|
164 |
+
use_relative_position=False,
|
165 |
+
rotary_emb=False,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
self.has_cross_attention = True
|
170 |
+
self.attn_num_head_channels = attn_num_head_channels
|
171 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
172 |
+
|
173 |
+
# there is always at least one resnet
|
174 |
+
resnets = [
|
175 |
+
ResnetBlock3D(
|
176 |
+
in_channels=in_channels,
|
177 |
+
out_channels=in_channels,
|
178 |
+
temb_channels=temb_channels,
|
179 |
+
eps=resnet_eps,
|
180 |
+
groups=resnet_groups,
|
181 |
+
dropout=dropout,
|
182 |
+
time_embedding_norm=resnet_time_scale_shift,
|
183 |
+
non_linearity=resnet_act_fn,
|
184 |
+
output_scale_factor=output_scale_factor,
|
185 |
+
pre_norm=resnet_pre_norm,
|
186 |
+
)
|
187 |
+
]
|
188 |
+
attentions = []
|
189 |
+
|
190 |
+
for _ in range(num_layers):
|
191 |
+
if dual_cross_attention:
|
192 |
+
raise NotImplementedError
|
193 |
+
attentions.append(
|
194 |
+
Transformer3DModel(
|
195 |
+
attn_num_head_channels,
|
196 |
+
in_channels // attn_num_head_channels,
|
197 |
+
in_channels=in_channels,
|
198 |
+
num_layers=1,
|
199 |
+
cross_attention_dim=cross_attention_dim,
|
200 |
+
norm_num_groups=resnet_groups,
|
201 |
+
use_linear_projection=use_linear_projection,
|
202 |
+
upcast_attention=upcast_attention,
|
203 |
+
use_first_frame=use_first_frame,
|
204 |
+
use_relative_position=use_relative_position,
|
205 |
+
rotary_emb=rotary_emb,
|
206 |
+
)
|
207 |
+
)
|
208 |
+
resnets.append(
|
209 |
+
ResnetBlock3D(
|
210 |
+
in_channels=in_channels,
|
211 |
+
out_channels=in_channels,
|
212 |
+
temb_channels=temb_channels,
|
213 |
+
eps=resnet_eps,
|
214 |
+
groups=resnet_groups,
|
215 |
+
dropout=dropout,
|
216 |
+
time_embedding_norm=resnet_time_scale_shift,
|
217 |
+
non_linearity=resnet_act_fn,
|
218 |
+
output_scale_factor=output_scale_factor,
|
219 |
+
pre_norm=resnet_pre_norm,
|
220 |
+
)
|
221 |
+
)
|
222 |
+
|
223 |
+
self.attentions = nn.ModuleList(attentions)
|
224 |
+
self.resnets = nn.ModuleList(resnets)
|
225 |
+
|
226 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
|
227 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
228 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
229 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
|
230 |
+
hidden_states = resnet(hidden_states, temb)
|
231 |
+
|
232 |
+
return hidden_states
|
233 |
+
|
234 |
+
|
235 |
+
class CrossAttnDownBlock3D(nn.Module):
|
236 |
+
def __init__(
|
237 |
+
self,
|
238 |
+
in_channels: int,
|
239 |
+
out_channels: int,
|
240 |
+
temb_channels: int,
|
241 |
+
dropout: float = 0.0,
|
242 |
+
num_layers: int = 1,
|
243 |
+
resnet_eps: float = 1e-6,
|
244 |
+
resnet_time_scale_shift: str = "default",
|
245 |
+
resnet_act_fn: str = "swish",
|
246 |
+
resnet_groups: int = 32,
|
247 |
+
resnet_pre_norm: bool = True,
|
248 |
+
attn_num_head_channels=1,
|
249 |
+
cross_attention_dim=1280,
|
250 |
+
output_scale_factor=1.0,
|
251 |
+
downsample_padding=1,
|
252 |
+
add_downsample=True,
|
253 |
+
dual_cross_attention=False,
|
254 |
+
use_linear_projection=False,
|
255 |
+
only_cross_attention=False,
|
256 |
+
upcast_attention=False,
|
257 |
+
use_first_frame=False,
|
258 |
+
use_relative_position=False,
|
259 |
+
rotary_emb=False,
|
260 |
+
):
|
261 |
+
super().__init__()
|
262 |
+
resnets = []
|
263 |
+
attentions = []
|
264 |
+
|
265 |
+
# print(use_first_frame)
|
266 |
+
|
267 |
+
self.has_cross_attention = True
|
268 |
+
self.attn_num_head_channels = attn_num_head_channels
|
269 |
+
|
270 |
+
for i in range(num_layers):
|
271 |
+
in_channels = in_channels if i == 0 else out_channels
|
272 |
+
resnets.append(
|
273 |
+
ResnetBlock3D(
|
274 |
+
in_channels=in_channels,
|
275 |
+
out_channels=out_channels,
|
276 |
+
temb_channels=temb_channels,
|
277 |
+
eps=resnet_eps,
|
278 |
+
groups=resnet_groups,
|
279 |
+
dropout=dropout,
|
280 |
+
time_embedding_norm=resnet_time_scale_shift,
|
281 |
+
non_linearity=resnet_act_fn,
|
282 |
+
output_scale_factor=output_scale_factor,
|
283 |
+
pre_norm=resnet_pre_norm,
|
284 |
+
)
|
285 |
+
)
|
286 |
+
if dual_cross_attention:
|
287 |
+
raise NotImplementedError
|
288 |
+
attentions.append(
|
289 |
+
Transformer3DModel(
|
290 |
+
attn_num_head_channels,
|
291 |
+
out_channels // attn_num_head_channels,
|
292 |
+
in_channels=out_channels,
|
293 |
+
num_layers=1,
|
294 |
+
cross_attention_dim=cross_attention_dim,
|
295 |
+
norm_num_groups=resnet_groups,
|
296 |
+
use_linear_projection=use_linear_projection,
|
297 |
+
only_cross_attention=only_cross_attention,
|
298 |
+
upcast_attention=upcast_attention,
|
299 |
+
use_first_frame=use_first_frame,
|
300 |
+
use_relative_position=use_relative_position,
|
301 |
+
rotary_emb=rotary_emb,
|
302 |
+
)
|
303 |
+
)
|
304 |
+
self.attentions = nn.ModuleList(attentions)
|
305 |
+
self.resnets = nn.ModuleList(resnets)
|
306 |
+
|
307 |
+
if add_downsample:
|
308 |
+
self.downsamplers = nn.ModuleList(
|
309 |
+
[
|
310 |
+
Downsample3D(
|
311 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
312 |
+
)
|
313 |
+
]
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
self.downsamplers = None
|
317 |
+
|
318 |
+
self.gradient_checkpointing = False
|
319 |
+
|
320 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
|
321 |
+
output_states = ()
|
322 |
+
|
323 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
324 |
+
if self.training and self.gradient_checkpointing:
|
325 |
+
|
326 |
+
def create_custom_forward(module, return_dict=None):
|
327 |
+
def custom_forward(*inputs):
|
328 |
+
if return_dict is not None:
|
329 |
+
return module(*inputs, return_dict=return_dict)
|
330 |
+
else:
|
331 |
+
return module(*inputs)
|
332 |
+
|
333 |
+
return custom_forward
|
334 |
+
|
335 |
+
def create_custom_forward_attn(module, return_dict=None, use_image_num=None, ip_hidden_states=None):
|
336 |
+
def custom_forward(*inputs):
|
337 |
+
if return_dict is not None:
|
338 |
+
return module(*inputs, return_dict=return_dict, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
|
339 |
+
else:
|
340 |
+
return module(*inputs, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
|
341 |
+
|
342 |
+
return custom_forward
|
343 |
+
|
344 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
345 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
346 |
+
create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states),
|
347 |
+
hidden_states,
|
348 |
+
encoder_hidden_states,
|
349 |
+
)[0]
|
350 |
+
else:
|
351 |
+
hidden_states = resnet(hidden_states, temb)
|
352 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
|
353 |
+
|
354 |
+
output_states += (hidden_states,)
|
355 |
+
|
356 |
+
if self.downsamplers is not None:
|
357 |
+
for downsampler in self.downsamplers:
|
358 |
+
hidden_states = downsampler(hidden_states)
|
359 |
+
|
360 |
+
output_states += (hidden_states,)
|
361 |
+
|
362 |
+
return hidden_states, output_states
|
363 |
+
|
364 |
+
|
365 |
+
class DownBlock3D(nn.Module):
|
366 |
+
def __init__(
|
367 |
+
self,
|
368 |
+
in_channels: int,
|
369 |
+
out_channels: int,
|
370 |
+
temb_channels: int,
|
371 |
+
dropout: float = 0.0,
|
372 |
+
num_layers: int = 1,
|
373 |
+
resnet_eps: float = 1e-6,
|
374 |
+
resnet_time_scale_shift: str = "default",
|
375 |
+
resnet_act_fn: str = "swish",
|
376 |
+
resnet_groups: int = 32,
|
377 |
+
resnet_pre_norm: bool = True,
|
378 |
+
output_scale_factor=1.0,
|
379 |
+
add_downsample=True,
|
380 |
+
downsample_padding=1,
|
381 |
+
):
|
382 |
+
super().__init__()
|
383 |
+
resnets = []
|
384 |
+
|
385 |
+
for i in range(num_layers):
|
386 |
+
in_channels = in_channels if i == 0 else out_channels
|
387 |
+
resnets.append(
|
388 |
+
ResnetBlock3D(
|
389 |
+
in_channels=in_channels,
|
390 |
+
out_channels=out_channels,
|
391 |
+
temb_channels=temb_channels,
|
392 |
+
eps=resnet_eps,
|
393 |
+
groups=resnet_groups,
|
394 |
+
dropout=dropout,
|
395 |
+
time_embedding_norm=resnet_time_scale_shift,
|
396 |
+
non_linearity=resnet_act_fn,
|
397 |
+
output_scale_factor=output_scale_factor,
|
398 |
+
pre_norm=resnet_pre_norm,
|
399 |
+
)
|
400 |
+
)
|
401 |
+
|
402 |
+
self.resnets = nn.ModuleList(resnets)
|
403 |
+
|
404 |
+
if add_downsample:
|
405 |
+
self.downsamplers = nn.ModuleList(
|
406 |
+
[
|
407 |
+
Downsample3D(
|
408 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
409 |
+
)
|
410 |
+
]
|
411 |
+
)
|
412 |
+
else:
|
413 |
+
self.downsamplers = None
|
414 |
+
|
415 |
+
self.gradient_checkpointing = False
|
416 |
+
|
417 |
+
def forward(self, hidden_states, temb=None):
|
418 |
+
output_states = ()
|
419 |
+
|
420 |
+
for resnet in self.resnets:
|
421 |
+
if self.training and self.gradient_checkpointing:
|
422 |
+
|
423 |
+
def create_custom_forward(module):
|
424 |
+
def custom_forward(*inputs):
|
425 |
+
return module(*inputs)
|
426 |
+
|
427 |
+
return custom_forward
|
428 |
+
|
429 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
430 |
+
else:
|
431 |
+
hidden_states = resnet(hidden_states, temb)
|
432 |
+
|
433 |
+
output_states += (hidden_states,)
|
434 |
+
|
435 |
+
if self.downsamplers is not None:
|
436 |
+
for downsampler in self.downsamplers:
|
437 |
+
hidden_states = downsampler(hidden_states)
|
438 |
+
|
439 |
+
output_states += (hidden_states,)
|
440 |
+
|
441 |
+
return hidden_states, output_states
|
442 |
+
|
443 |
+
|
444 |
+
class CrossAttnUpBlock3D(nn.Module):
|
445 |
+
def __init__(
|
446 |
+
self,
|
447 |
+
in_channels: int,
|
448 |
+
out_channels: int,
|
449 |
+
prev_output_channel: int,
|
450 |
+
temb_channels: int,
|
451 |
+
dropout: float = 0.0,
|
452 |
+
num_layers: int = 1,
|
453 |
+
resnet_eps: float = 1e-6,
|
454 |
+
resnet_time_scale_shift: str = "default",
|
455 |
+
resnet_act_fn: str = "swish",
|
456 |
+
resnet_groups: int = 32,
|
457 |
+
resnet_pre_norm: bool = True,
|
458 |
+
attn_num_head_channels=1,
|
459 |
+
cross_attention_dim=1280,
|
460 |
+
output_scale_factor=1.0,
|
461 |
+
add_upsample=True,
|
462 |
+
dual_cross_attention=False,
|
463 |
+
use_linear_projection=False,
|
464 |
+
only_cross_attention=False,
|
465 |
+
upcast_attention=False,
|
466 |
+
use_first_frame=False,
|
467 |
+
use_relative_position=False,
|
468 |
+
rotary_emb=False
|
469 |
+
):
|
470 |
+
super().__init__()
|
471 |
+
resnets = []
|
472 |
+
attentions = []
|
473 |
+
|
474 |
+
self.has_cross_attention = True
|
475 |
+
self.attn_num_head_channels = attn_num_head_channels
|
476 |
+
|
477 |
+
for i in range(num_layers):
|
478 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
479 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
480 |
+
|
481 |
+
resnets.append(
|
482 |
+
ResnetBlock3D(
|
483 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
484 |
+
out_channels=out_channels,
|
485 |
+
temb_channels=temb_channels,
|
486 |
+
eps=resnet_eps,
|
487 |
+
groups=resnet_groups,
|
488 |
+
dropout=dropout,
|
489 |
+
time_embedding_norm=resnet_time_scale_shift,
|
490 |
+
non_linearity=resnet_act_fn,
|
491 |
+
output_scale_factor=output_scale_factor,
|
492 |
+
pre_norm=resnet_pre_norm,
|
493 |
+
)
|
494 |
+
)
|
495 |
+
if dual_cross_attention:
|
496 |
+
raise NotImplementedError
|
497 |
+
attentions.append(
|
498 |
+
Transformer3DModel(
|
499 |
+
attn_num_head_channels,
|
500 |
+
out_channels // attn_num_head_channels,
|
501 |
+
in_channels=out_channels,
|
502 |
+
num_layers=1,
|
503 |
+
cross_attention_dim=cross_attention_dim,
|
504 |
+
norm_num_groups=resnet_groups,
|
505 |
+
use_linear_projection=use_linear_projection,
|
506 |
+
only_cross_attention=only_cross_attention,
|
507 |
+
upcast_attention=upcast_attention,
|
508 |
+
use_first_frame=use_first_frame,
|
509 |
+
use_relative_position=use_relative_position,
|
510 |
+
rotary_emb=rotary_emb,
|
511 |
+
)
|
512 |
+
)
|
513 |
+
|
514 |
+
self.attentions = nn.ModuleList(attentions)
|
515 |
+
self.resnets = nn.ModuleList(resnets)
|
516 |
+
|
517 |
+
if add_upsample:
|
518 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
519 |
+
else:
|
520 |
+
self.upsamplers = None
|
521 |
+
|
522 |
+
self.gradient_checkpointing = False
|
523 |
+
|
524 |
+
def forward(
|
525 |
+
self,
|
526 |
+
hidden_states,
|
527 |
+
res_hidden_states_tuple,
|
528 |
+
temb=None,
|
529 |
+
encoder_hidden_states=None,
|
530 |
+
upsample_size=None,
|
531 |
+
attention_mask=None,
|
532 |
+
use_image_num=None,
|
533 |
+
ip_hidden_states=None,
|
534 |
+
encoder_temporal_hidden_states=None
|
535 |
+
):
|
536 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
537 |
+
# pop res hidden states
|
538 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
539 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
540 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
541 |
+
|
542 |
+
if self.training and self.gradient_checkpointing:
|
543 |
+
|
544 |
+
def create_custom_forward(module, return_dict=None):
|
545 |
+
def custom_forward(*inputs):
|
546 |
+
if return_dict is not None:
|
547 |
+
return module(*inputs, return_dict=return_dict)
|
548 |
+
else:
|
549 |
+
return module(*inputs)
|
550 |
+
|
551 |
+
return custom_forward
|
552 |
+
|
553 |
+
def create_custom_forward_attn(module, return_dict=None, use_image_num=None, ip_hidden_states=None):
|
554 |
+
def custom_forward(*inputs):
|
555 |
+
if return_dict is not None:
|
556 |
+
return module(*inputs, return_dict=return_dict, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
|
557 |
+
else:
|
558 |
+
return module(*inputs, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
|
559 |
+
|
560 |
+
return custom_forward
|
561 |
+
|
562 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
563 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
564 |
+
create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states),
|
565 |
+
hidden_states,
|
566 |
+
encoder_hidden_states,
|
567 |
+
)[0]
|
568 |
+
else:
|
569 |
+
hidden_states = resnet(hidden_states, temb)
|
570 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
|
571 |
+
|
572 |
+
if self.upsamplers is not None:
|
573 |
+
for upsampler in self.upsamplers:
|
574 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
575 |
+
|
576 |
+
return hidden_states
|
577 |
+
|
578 |
+
|
579 |
+
class UpBlock3D(nn.Module):
|
580 |
+
def __init__(
|
581 |
+
self,
|
582 |
+
in_channels: int,
|
583 |
+
prev_output_channel: int,
|
584 |
+
out_channels: int,
|
585 |
+
temb_channels: int,
|
586 |
+
dropout: float = 0.0,
|
587 |
+
num_layers: int = 1,
|
588 |
+
resnet_eps: float = 1e-6,
|
589 |
+
resnet_time_scale_shift: str = "default",
|
590 |
+
resnet_act_fn: str = "swish",
|
591 |
+
resnet_groups: int = 32,
|
592 |
+
resnet_pre_norm: bool = True,
|
593 |
+
output_scale_factor=1.0,
|
594 |
+
add_upsample=True,
|
595 |
+
):
|
596 |
+
super().__init__()
|
597 |
+
resnets = []
|
598 |
+
|
599 |
+
for i in range(num_layers):
|
600 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
601 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
602 |
+
|
603 |
+
resnets.append(
|
604 |
+
ResnetBlock3D(
|
605 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
606 |
+
out_channels=out_channels,
|
607 |
+
temb_channels=temb_channels,
|
608 |
+
eps=resnet_eps,
|
609 |
+
groups=resnet_groups,
|
610 |
+
dropout=dropout,
|
611 |
+
time_embedding_norm=resnet_time_scale_shift,
|
612 |
+
non_linearity=resnet_act_fn,
|
613 |
+
output_scale_factor=output_scale_factor,
|
614 |
+
pre_norm=resnet_pre_norm,
|
615 |
+
)
|
616 |
+
)
|
617 |
+
|
618 |
+
self.resnets = nn.ModuleList(resnets)
|
619 |
+
|
620 |
+
if add_upsample:
|
621 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
622 |
+
else:
|
623 |
+
self.upsamplers = None
|
624 |
+
|
625 |
+
self.gradient_checkpointing = False
|
626 |
+
|
627 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
628 |
+
for resnet in self.resnets:
|
629 |
+
# pop res hidden states
|
630 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
631 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
632 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
633 |
+
|
634 |
+
if self.training and self.gradient_checkpointing:
|
635 |
+
|
636 |
+
def create_custom_forward(module):
|
637 |
+
def custom_forward(*inputs):
|
638 |
+
return module(*inputs)
|
639 |
+
|
640 |
+
return custom_forward
|
641 |
+
|
642 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
643 |
+
else:
|
644 |
+
hidden_states = resnet(hidden_states, temb)
|
645 |
+
|
646 |
+
if self.upsamplers is not None:
|
647 |
+
for upsampler in self.upsamplers:
|
648 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
649 |
+
|
650 |
+
return hidden_states
|
models/utils.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
|
11 |
+
import os
|
12 |
+
import math
|
13 |
+
import torch
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
from einops import repeat
|
19 |
+
|
20 |
+
|
21 |
+
#################################################################################
|
22 |
+
# Unet Utils #
|
23 |
+
#################################################################################
|
24 |
+
|
25 |
+
def checkpoint(func, inputs, params, flag):
|
26 |
+
"""
|
27 |
+
Evaluate a function without caching intermediate activations, allowing for
|
28 |
+
reduced memory at the expense of extra compute in the backward pass.
|
29 |
+
:param func: the function to evaluate.
|
30 |
+
:param inputs: the argument sequence to pass to `func`.
|
31 |
+
:param params: a sequence of parameters `func` depends on but does not
|
32 |
+
explicitly take as arguments.
|
33 |
+
:param flag: if False, disable gradient checkpointing.
|
34 |
+
"""
|
35 |
+
if flag:
|
36 |
+
args = tuple(inputs) + tuple(params)
|
37 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
38 |
+
else:
|
39 |
+
return func(*inputs)
|
40 |
+
|
41 |
+
|
42 |
+
class CheckpointFunction(torch.autograd.Function):
|
43 |
+
@staticmethod
|
44 |
+
def forward(ctx, run_function, length, *args):
|
45 |
+
ctx.run_function = run_function
|
46 |
+
ctx.input_tensors = list(args[:length])
|
47 |
+
ctx.input_params = list(args[length:])
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
51 |
+
return output_tensors
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def backward(ctx, *output_grads):
|
55 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
56 |
+
with torch.enable_grad():
|
57 |
+
# Fixes a bug where the first op in run_function modifies the
|
58 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
59 |
+
# Tensors.
|
60 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
61 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
62 |
+
input_grads = torch.autograd.grad(
|
63 |
+
output_tensors,
|
64 |
+
ctx.input_tensors + ctx.input_params,
|
65 |
+
output_grads,
|
66 |
+
allow_unused=True,
|
67 |
+
)
|
68 |
+
del ctx.input_tensors
|
69 |
+
del ctx.input_params
|
70 |
+
del output_tensors
|
71 |
+
return (None, None) + input_grads
|
72 |
+
|
73 |
+
|
74 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
75 |
+
"""
|
76 |
+
Create sinusoidal timestep embeddings.
|
77 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
78 |
+
These may be fractional.
|
79 |
+
:param dim: the dimension of the output.
|
80 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
81 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
82 |
+
"""
|
83 |
+
if not repeat_only:
|
84 |
+
half = dim // 2
|
85 |
+
freqs = torch.exp(
|
86 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
87 |
+
).to(device=timesteps.device)
|
88 |
+
args = timesteps[:, None].float() * freqs[None]
|
89 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
90 |
+
if dim % 2:
|
91 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
92 |
+
else:
|
93 |
+
embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
|
94 |
+
return embedding
|
95 |
+
|
96 |
+
|
97 |
+
def zero_module(module):
|
98 |
+
"""
|
99 |
+
Zero out the parameters of a module and return it.
|
100 |
+
"""
|
101 |
+
for p in module.parameters():
|
102 |
+
p.detach().zero_()
|
103 |
+
return module
|
104 |
+
|
105 |
+
|
106 |
+
def scale_module(module, scale):
|
107 |
+
"""
|
108 |
+
Scale the parameters of a module and return it.
|
109 |
+
"""
|
110 |
+
for p in module.parameters():
|
111 |
+
p.detach().mul_(scale)
|
112 |
+
return module
|
113 |
+
|
114 |
+
|
115 |
+
def mean_flat(tensor):
|
116 |
+
"""
|
117 |
+
Take the mean over all non-batch dimensions.
|
118 |
+
"""
|
119 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
120 |
+
|
121 |
+
|
122 |
+
def normalization(channels):
|
123 |
+
"""
|
124 |
+
Make a standard normalization layer.
|
125 |
+
:param channels: number of input channels.
|
126 |
+
:return: an nn.Module for normalization.
|
127 |
+
"""
|
128 |
+
return GroupNorm32(32, channels)
|
129 |
+
|
130 |
+
|
131 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
132 |
+
class SiLU(nn.Module):
|
133 |
+
def forward(self, x):
|
134 |
+
return x * torch.sigmoid(x)
|
135 |
+
|
136 |
+
|
137 |
+
class GroupNorm32(nn.GroupNorm):
|
138 |
+
def forward(self, x):
|
139 |
+
return super().forward(x.float()).type(x.dtype)
|
140 |
+
|
141 |
+
def conv_nd(dims, *args, **kwargs):
|
142 |
+
"""
|
143 |
+
Create a 1D, 2D, or 3D convolution module.
|
144 |
+
"""
|
145 |
+
if dims == 1:
|
146 |
+
return nn.Conv1d(*args, **kwargs)
|
147 |
+
elif dims == 2:
|
148 |
+
return nn.Conv2d(*args, **kwargs)
|
149 |
+
elif dims == 3:
|
150 |
+
return nn.Conv3d(*args, **kwargs)
|
151 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
152 |
+
|
153 |
+
|
154 |
+
def linear(*args, **kwargs):
|
155 |
+
"""
|
156 |
+
Create a linear module.
|
157 |
+
"""
|
158 |
+
return nn.Linear(*args, **kwargs)
|
159 |
+
|
160 |
+
|
161 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
162 |
+
"""
|
163 |
+
Create a 1D, 2D, or 3D average pooling module.
|
164 |
+
"""
|
165 |
+
if dims == 1:
|
166 |
+
return nn.AvgPool1d(*args, **kwargs)
|
167 |
+
elif dims == 2:
|
168 |
+
return nn.AvgPool2d(*args, **kwargs)
|
169 |
+
elif dims == 3:
|
170 |
+
return nn.AvgPool3d(*args, **kwargs)
|
171 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
172 |
+
|
173 |
+
|
174 |
+
# class HybridConditioner(nn.Module):
|
175 |
+
|
176 |
+
# def __init__(self, c_concat_config, c_crossattn_config):
|
177 |
+
# super().__init__()
|
178 |
+
# self.concat_conditioner = instantiate_from_config(c_concat_config)
|
179 |
+
# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
180 |
+
|
181 |
+
# def forward(self, c_concat, c_crossattn):
|
182 |
+
# c_concat = self.concat_conditioner(c_concat)
|
183 |
+
# c_crossattn = self.crossattn_conditioner(c_crossattn)
|
184 |
+
# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
185 |
+
|
186 |
+
|
187 |
+
def noise_like(shape, device, repeat=False):
|
188 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
189 |
+
noise = lambda: torch.randn(shape, device=device)
|
190 |
+
return repeat_noise() if repeat else noise()
|
191 |
+
|
192 |
+
def count_flops_attn(model, _x, y):
|
193 |
+
"""
|
194 |
+
A counter for the `thop` package to count the operations in an
|
195 |
+
attention operation.
|
196 |
+
Meant to be used like:
|
197 |
+
macs, params = thop.profile(
|
198 |
+
model,
|
199 |
+
inputs=(inputs, timestamps),
|
200 |
+
custom_ops={QKVAttention: QKVAttention.count_flops},
|
201 |
+
)
|
202 |
+
"""
|
203 |
+
b, c, *spatial = y[0].shape
|
204 |
+
num_spatial = int(np.prod(spatial))
|
205 |
+
# We perform two matmuls with the same number of ops.
|
206 |
+
# The first computes the weight matrix, the second computes
|
207 |
+
# the combination of the value vectors.
|
208 |
+
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
209 |
+
model.total_ops += torch.DoubleTensor([matmul_ops])
|
210 |
+
|
211 |
+
def count_params(model, verbose=False):
|
212 |
+
total_params = sum(p.numel() for p in model.parameters())
|
213 |
+
if verbose:
|
214 |
+
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
215 |
+
return total_params
|
requirements.txt
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bark_ssg==1.3.4
|
2 |
+
decord==0.6.0
|
3 |
+
diffusers==0.25.0
|
4 |
+
einops==0.7.0
|
5 |
+
imageio==2.28.0
|
6 |
+
ipython==8.14.0
|
7 |
+
librosa==0.10.1
|
8 |
+
mmcv==2.1.0
|
9 |
+
moviepy==1.0.3
|
10 |
+
natsort==8.3.1
|
11 |
+
nltk==3.8.1
|
12 |
+
numpy==1.23.5
|
13 |
+
omegaconf==2.3.0
|
14 |
+
openai==0.27.8
|
15 |
+
opencv_python==4.7.0.72
|
16 |
+
Pillow==9.4.0
|
17 |
+
Pillow==10.2.0
|
18 |
+
pytorch_lightning==2.0.2
|
19 |
+
rotary_embedding_torch==0.2.3
|
20 |
+
soundfile==0.12.1
|
21 |
+
torch==2.0.0
|
22 |
+
torchvision==0.15.0
|
23 |
+
tqdm==4.65.0
|
24 |
+
transformers==4.28.1
|
25 |
+
xformers==0.0.19
|
results/mask_no_ref/Planet_hits_earth..mp4
ADDED
Binary file (326 kB). View file
|
|
results/mask_ref/Planet_hits_earth..mp4
ADDED
Binary file (345 kB). View file
|
|
results/vlog/teddy_travel/ref_img/teddy.jpg
ADDED
results/vlog/teddy_travel/script/protagonist_place_reference.txt
ADDED
Binary file (1.53 kB). View file
|
|
results/vlog/teddy_travel/script/protagonists_places.txt
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"id": 1,
|
4 |
+
"name": "Teddy",
|
5 |
+
"description": "A teddy bear with a dream of traveling the world"
|
6 |
+
},
|
7 |
+
{
|
8 |
+
"id": 2,
|
9 |
+
"name": "Eiffel Tower",
|
10 |
+
"description": "An iconic wrought-iron lattice tower located in Paris, France"
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"id": 3,
|
14 |
+
"name": "Great Wall",
|
15 |
+
"description": "A vast, historic fortification system that stretches across the northern part of China"
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"id": 4,
|
19 |
+
"name": "Pyramids",
|
20 |
+
"description": "Ancient monumental structures located in Egypt"
|
21 |
+
}
|
22 |
+
]
|
results/vlog/teddy_travel/script/time_scripts.txt
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"video fragment id": 1,
|
4 |
+
"time": 2
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"video fragment id": 2,
|
8 |
+
"time": 3
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"video fragment id": 3,
|
12 |
+
"time": 3
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"video fragment id": 4,
|
16 |
+
"time": 2
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"video fragment id": 5,
|
20 |
+
"time": 2
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"video fragment id": 6,
|
24 |
+
"time": 3
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"video fragment id": 7,
|
28 |
+
"time": 2
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"video fragment id": 8,
|
32 |
+
"time": 3
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"video fragment id": 9,
|
36 |
+
"time": 2
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"video fragment id": 10,
|
40 |
+
"time": 2
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"video fragment id": 11,
|
44 |
+
"time": 3
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"video fragment id": 12,
|
48 |
+
"time": 2
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"video fragment id": 13,
|
52 |
+
"time": 2
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"video fragment id": 14,
|
56 |
+
"time": 3
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"video fragment id": 15,
|
60 |
+
"time": 3
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"video fragment id": 16,
|
64 |
+
"time": 2
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"video fragment id": 17,
|
68 |
+
"time": 3
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"video fragment id": 18,
|
72 |
+
"time": 2
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"video fragment id": 19,
|
76 |
+
"time": 3
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"video fragment id": 20,
|
80 |
+
"time": 2
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"video fragment id": 21,
|
84 |
+
"time": 3
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"video fragment id": 22,
|
88 |
+
"time": 2
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"video fragment id": 23,
|
92 |
+
"time": 3
|
93 |
+
}
|
94 |
+
]
|
results/vlog/teddy_travel/script/video_prompts.txt
ADDED
Binary file (2.61 kB). View file
|
|
results/vlog/teddy_travel/script/zh_video_prompts.txt
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"序号": 1,
|
4 |
+
"描述": "泰迪熊在孩子的房间里。",
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"序号": 2,
|
8 |
+
"描述": "泰迪熊正在做梦。",
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"序号": 3,
|
12 |
+
"描述": "梦想着旅行。",
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"序号": 4,
|
16 |
+
"描述": "泰迪熊在机场。",
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"序号": 5,
|
20 |
+
"描述": "泰迪熊从背包中探出头来。",
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"序号": 6,
|
24 |
+
"描述": "泰迪熊在野餐毯上。",
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"序号": 7,
|
28 |
+
"描述": "背景是埃菲尔铁塔。",
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"序号": 8,
|
32 |
+
"描述": "泰迪熊正在享受巴黎野餐。",
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"序号": 9,
|
36 |
+
"描述": "泰迪熊周围是羊角面包。",
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"序号": 10,
|
40 |
+
"描述": "泰迪熊在长城顶部。",
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"序号": 11,
|
44 |
+
"描述": "泰迪熊正在欣赏风景。",
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"序号": 12,
|
48 |
+
"描述": "泰迪熊在埃及探索金字塔。",
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"序号": 13,
|
52 |
+
"描述": "炎热的埃及阳光下。",
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"序号": 14,
|
56 |
+
"描述": "泰迪熊找到了一个宝箱。",
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"序号": 15,
|
60 |
+
"描述": "宝箱在金字塔内部。",
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"序号": 16,
|
64 |
+
"描述": "泰迪熊回到卧室。",
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"序号": 17,
|
68 |
+
"描述": "分享旅行故事。",
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"序号": 18,
|
72 |
+
"描述": "一个小女孩在反应。",
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"序号": 19,
|
76 |
+
"描述": "惊讶于泰迪熊的故事。",
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"序号": 20,
|
80 |
+
"描述": "房间里满是纪念品。",
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"序号": 21,
|
84 |
+
"描述": "来自泰迪熊旅行的纪念品。",
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"序号": 22,
|
88 |
+
"描述": "泰迪熊正在看世界地图。",
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"序号": 23,
|
92 |
+
"描述": "梦想着下一次的冒险。",
|
93 |
+
}
|
94 |
+
|
95 |
+
]
|
results/vlog/teddy_travel/story.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Once upon a time, there was a teddy bear named Teddy who dreamed of traveling the world. One day, his dream came true to travel around the world. Teddy sat in the airport lobby and traveled to many places of interest. Along the way, Teddy visited the Eiffel Tower, the Great Wall, and the pyramids. In Paris, Teddy had a picnic and enjoyed some delicious croissants. At the Great Wall of China, he climbed to the top and marveled at the breathtaking view. And in Egypt, he explored the pyramids and even found a secret treasure hidden inside. After his exciting journey, Teddy was eventually reunited with his owner who was thrilled to hear about all of Teddy’s adventures. From that day on, Teddy always dreamed of traveling the world again and experiencing new and exciting things.
|
results/vlog/teddy_travel_/story.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Once upon a time, there was a teddy bear named Teddy who dreamed of traveling the world. One day, his dream came true to travel around the world. Teddy sat in the airport lobby and traveled to many places of interest. Along the way, Teddy visited the Eiffel Tower, the Great Wall, and the pyramids. In Paris, Teddy had a picnic and enjoyed some delicious croissants. At the Great Wall of China, he climbed to the top and marveled at the breathtaking view. And in Egypt, he explored the pyramids and even found a secret treasure hidden inside. After his exciting journey, Teddy was eventually reunited with his owner who was thrilled to hear about all of Teddy’s adventures. From that day on, Teddy always dreamed of traveling the world again and experiencing new and exciting things.
|
sample_scripts/vlog_read_script_sample.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
4 |
+
torch.backends.cudnn.allow_tf32 = True
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
try:
|
8 |
+
import utils
|
9 |
+
from diffusion import create_diffusion
|
10 |
+
except:
|
11 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
12 |
+
import utils
|
13 |
+
from diffusion import create_diffusion
|
14 |
+
import argparse
|
15 |
+
import torchvision
|
16 |
+
from PIL import Image
|
17 |
+
from einops import rearrange
|
18 |
+
from models import get_models
|
19 |
+
from diffusers.models import AutoencoderKL
|
20 |
+
from models.clip import TextEmbedder
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from pytorch_lightning import seed_everything
|
23 |
+
from utils import mask_generation_before
|
24 |
+
from diffusers.utils.import_utils import is_xformers_available
|
25 |
+
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
|
26 |
+
from vlogger.videofusion import fusion
|
27 |
+
from vlogger.videocaption import captioning
|
28 |
+
from vlogger.videoaudio import make_audio, merge_video_audio, concatenate_videos
|
29 |
+
from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model
|
30 |
+
from vlogger.planning_utils.gpt4_utils import (readscript,
|
31 |
+
readtimescript,
|
32 |
+
readprotagonistscript,
|
33 |
+
readreferencescript,
|
34 |
+
readzhscript)
|
35 |
+
|
36 |
+
|
37 |
+
def auto_inpainting(args,
|
38 |
+
video_input,
|
39 |
+
masked_video,
|
40 |
+
mask,
|
41 |
+
prompt,
|
42 |
+
image,
|
43 |
+
vae,
|
44 |
+
text_encoder,
|
45 |
+
image_encoder,
|
46 |
+
diffusion,
|
47 |
+
model,
|
48 |
+
device,
|
49 |
+
):
|
50 |
+
image_prompt_embeds = None
|
51 |
+
if prompt is None:
|
52 |
+
prompt = ""
|
53 |
+
if image is not None:
|
54 |
+
clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
|
55 |
+
clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
|
56 |
+
uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
|
57 |
+
image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
|
58 |
+
image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
|
59 |
+
model = ip_scale_set(model, args.ref_cfg_scale)
|
60 |
+
if args.use_fp16:
|
61 |
+
image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
|
62 |
+
b, f, c, h, w = video_input.shape
|
63 |
+
latent_h = video_input.shape[-2] // 8
|
64 |
+
latent_w = video_input.shape[-1] // 8
|
65 |
+
|
66 |
+
if args.use_fp16:
|
67 |
+
z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
|
68 |
+
masked_video = masked_video.to(dtype=torch.float16)
|
69 |
+
mask = mask.to(dtype=torch.float16)
|
70 |
+
else:
|
71 |
+
z = torch.randn(1, 4, 16, latent_h, latent_w, device=device) # b,c,f,h,w
|
72 |
+
|
73 |
+
masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
|
74 |
+
masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
|
75 |
+
masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
|
76 |
+
mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
|
77 |
+
masked_video = torch.cat([masked_video] * 2)
|
78 |
+
mask = torch.cat([mask] * 2)
|
79 |
+
z = torch.cat([z] * 2)
|
80 |
+
prompt_all = [prompt] + [args.negative_prompt]
|
81 |
+
|
82 |
+
text_prompt = text_encoder(text_prompts=prompt_all, train=False)
|
83 |
+
model_kwargs = dict(encoder_hidden_states=text_prompt,
|
84 |
+
class_labels=None,
|
85 |
+
cfg_scale=args.cfg_scale,
|
86 |
+
use_fp16=args.use_fp16,
|
87 |
+
ip_hidden_states=image_prompt_embeds)
|
88 |
+
|
89 |
+
# Sample images:
|
90 |
+
samples = diffusion.ddim_sample_loop(model.forward_with_cfg,
|
91 |
+
z.shape,
|
92 |
+
z,
|
93 |
+
clip_denoised=False,
|
94 |
+
model_kwargs=model_kwargs,
|
95 |
+
progress=True,
|
96 |
+
device=device,
|
97 |
+
mask=mask,
|
98 |
+
x_start=masked_video,
|
99 |
+
use_concat=True,
|
100 |
+
)
|
101 |
+
samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
|
102 |
+
if args.use_fp16:
|
103 |
+
samples = samples.to(dtype=torch.float16)
|
104 |
+
|
105 |
+
video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
|
106 |
+
video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
|
107 |
+
return video_clip
|
108 |
+
|
109 |
+
|
110 |
+
def main(args):
|
111 |
+
# Setup PyTorch:
|
112 |
+
if args.seed:
|
113 |
+
torch.manual_seed(args.seed)
|
114 |
+
torch.set_grad_enabled(False)
|
115 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
116 |
+
seed_everything(args.seed)
|
117 |
+
|
118 |
+
model = get_models(args).to(device)
|
119 |
+
model = tca_transform_model(model).to(device)
|
120 |
+
model = ip_transform_model(model).to(device)
|
121 |
+
if args.enable_xformers_memory_efficient_attention:
|
122 |
+
if is_xformers_available():
|
123 |
+
model.enable_xformers_memory_efficient_attention()
|
124 |
+
else:
|
125 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
126 |
+
if args.use_compile:
|
127 |
+
model = torch.compile(model)
|
128 |
+
|
129 |
+
ckpt_path = args.ckpt
|
130 |
+
state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
|
131 |
+
model_dict = model.state_dict()
|
132 |
+
pretrained_dict = {}
|
133 |
+
for k, v in state_dict.items():
|
134 |
+
if k in model_dict:
|
135 |
+
pretrained_dict[k] = v
|
136 |
+
model_dict.update(pretrained_dict)
|
137 |
+
model.load_state_dict(model_dict)
|
138 |
+
|
139 |
+
model.eval() # important!
|
140 |
+
diffusion = create_diffusion(str(args.num_sampling_steps))
|
141 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
|
142 |
+
text_encoder = text_encoder = TextEmbedder(args.pretrained_model_path).to(device)
|
143 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
|
144 |
+
if args.use_fp16:
|
145 |
+
print('Warnning: using half percision for inferencing!')
|
146 |
+
vae.to(dtype=torch.float16)
|
147 |
+
model.to(dtype=torch.float16)
|
148 |
+
text_encoder.to(dtype=torch.float16)
|
149 |
+
print("model ready!\n", flush=True)
|
150 |
+
|
151 |
+
|
152 |
+
# load protagonist script
|
153 |
+
character_places = readprotagonistscript(args.protagonist_file_path)
|
154 |
+
print("protagonists ready!", flush=True)
|
155 |
+
|
156 |
+
# load script
|
157 |
+
video_list = readscript(args.script_file_path)
|
158 |
+
print("video script ready!", flush=True)
|
159 |
+
|
160 |
+
# load reference script
|
161 |
+
reference_lists = readreferencescript(video_list, character_places, args.reference_file_path)
|
162 |
+
print("reference script ready!", flush=True)
|
163 |
+
|
164 |
+
# load zh script
|
165 |
+
zh_video_list = readzhscript(args.zh_script_file_path)
|
166 |
+
print("zh script ready!", flush=True)
|
167 |
+
|
168 |
+
# load time script
|
169 |
+
key_list = []
|
170 |
+
for key, value in character_places.items():
|
171 |
+
key_list.append(key)
|
172 |
+
time_list = readtimescript(args.time_file_path)
|
173 |
+
print("time script ready!", flush=True)
|
174 |
+
|
175 |
+
|
176 |
+
# generation begin
|
177 |
+
sample_list = []
|
178 |
+
for i, text_prompt in enumerate(video_list):
|
179 |
+
sample_list.append([])
|
180 |
+
for time in range(time_list[i]):
|
181 |
+
if time == 0:
|
182 |
+
print('Generating the ({}) prompt'.format(text_prompt), flush=True)
|
183 |
+
if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
|
184 |
+
pil_image = None
|
185 |
+
else:
|
186 |
+
pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
|
187 |
+
pil_image.resize((256, 256))
|
188 |
+
video_input = torch.zeros([1, 16, 3, args.image_size[0], args.image_size[1]]).to(device)
|
189 |
+
mask = mask_generation_before("first0", video_input.shape, video_input.dtype, device) # b,f,c,h,w
|
190 |
+
masked_video = video_input * (mask == 0)
|
191 |
+
samples = auto_inpainting(args,
|
192 |
+
video_input,
|
193 |
+
masked_video,
|
194 |
+
mask,
|
195 |
+
text_prompt,
|
196 |
+
pil_image,
|
197 |
+
vae,
|
198 |
+
text_encoder,
|
199 |
+
image_encoder,
|
200 |
+
diffusion,
|
201 |
+
model,
|
202 |
+
device,
|
203 |
+
)
|
204 |
+
sample_list[i].append(samples)
|
205 |
+
else:
|
206 |
+
if sum(video.shape[0] for video in sample_list[i]) / args.fps >= time_list[i]:
|
207 |
+
break
|
208 |
+
print('Generating the ({}) prompt'.format(text_prompt), flush=True)
|
209 |
+
if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
|
210 |
+
pil_image = None
|
211 |
+
else:
|
212 |
+
pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
|
213 |
+
pil_image.resize((256, 256))
|
214 |
+
pre_video = sample_list[i][-1][-args.researve_frame:]
|
215 |
+
f, c, h, w = pre_video.shape
|
216 |
+
lat_video = torch.zeros(args.num_frames - args.researve_frame, c, h, w).to(device)
|
217 |
+
video_input = torch.concat([pre_video, lat_video], dim=0)
|
218 |
+
video_input = video_input.to(device).unsqueeze(0)
|
219 |
+
mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
|
220 |
+
masked_video = video_input * (mask == 0)
|
221 |
+
video_clip = auto_inpainting(args,
|
222 |
+
video_input,
|
223 |
+
masked_video,
|
224 |
+
mask,
|
225 |
+
text_prompt,
|
226 |
+
pil_image,
|
227 |
+
vae,
|
228 |
+
text_encoder,
|
229 |
+
image_encoder,
|
230 |
+
diffusion,
|
231 |
+
model,
|
232 |
+
device,
|
233 |
+
)
|
234 |
+
sample_list[i].append(video_clip[args.researve_frame:])
|
235 |
+
print(video_clip[args.researve_frame:].shape)
|
236 |
+
|
237 |
+
# transition
|
238 |
+
if args.video_transition and i != 0:
|
239 |
+
video_1 = sample_list[i - 1][-1][-1:]
|
240 |
+
video_2 = sample_list[i][0][:1]
|
241 |
+
f, c, h, w = video_1.shape
|
242 |
+
video_middle = torch.zeros(args.num_frames - 2, c, h, w).to(device)
|
243 |
+
video_input = torch.concat([video_1, video_middle, video_2], dim=0)
|
244 |
+
video_input = video_input.to(device).unsqueeze(0)
|
245 |
+
mask = mask_generation_before("onelast1", video_input.shape, video_input.dtype, device)
|
246 |
+
masked_video = masked_video = video_input * (mask == 0)
|
247 |
+
video_clip = auto_inpainting(args,
|
248 |
+
video_input,
|
249 |
+
masked_video,
|
250 |
+
mask,
|
251 |
+
"smooth transition, slow motion, slow changing.",
|
252 |
+
pil_image,
|
253 |
+
vae,
|
254 |
+
text_encoder,
|
255 |
+
image_encoder,
|
256 |
+
diffusion,
|
257 |
+
model,
|
258 |
+
device,
|
259 |
+
)
|
260 |
+
sample_list[i].insert(0, video_clip[1:-1])
|
261 |
+
|
262 |
+
# save videos
|
263 |
+
samples = torch.concat(sample_list[i], dim=0)
|
264 |
+
samples = samples[0: time_list[i] * args.fps]
|
265 |
+
if not os.path.exists(args.save_origin_video_path):
|
266 |
+
os.makedirs(args.save_origin_video_path)
|
267 |
+
video_ = ((samples * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
|
268 |
+
torchvision.io.write_video(args.save_origin_video_path + "/" + f"{i}" + '.mp4', video_, fps=args.fps)
|
269 |
+
|
270 |
+
# post processing
|
271 |
+
fusion(args.save_origin_video_path)
|
272 |
+
captioning(args.script_file_path, args.zh_script_file_path, args.save_origin_video_path, args.save_caption_video_path)
|
273 |
+
fusion(args.save_caption_video_path)
|
274 |
+
make_audio(args.script_file_path, args.save_audio_path)
|
275 |
+
merge_video_audio(args.save_caption_video_path, args.save_audio_path, args.save_audio_caption_video_path)
|
276 |
+
concatenate_videos(args.save_audio_caption_video_path)
|
277 |
+
print('final video save path {}'.format(args.save_audio_caption_video_path))
|
278 |
+
|
279 |
+
|
280 |
+
if __name__ == "__main__":
|
281 |
+
parser = argparse.ArgumentParser()
|
282 |
+
parser.add_argument("--config", type=str, default="configs/vlog_read_script_sample.yaml")
|
283 |
+
args = parser.parse_args()
|
284 |
+
omega_conf = OmegaConf.load(args.config)
|
285 |
+
save_path = omega_conf.save_path
|
286 |
+
save_origin_video_path = os.path.join(save_path, "origin_video")
|
287 |
+
save_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "caption_video")
|
288 |
+
save_audio_path = os.path.join(save_path.rsplit('/', 1)[0], "audio")
|
289 |
+
save_audio_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "audio_caption_video")
|
290 |
+
if omega_conf.sample_num is not None:
|
291 |
+
for i in range(omega_conf.sample_num):
|
292 |
+
omega_conf.save_origin_video_path = save_origin_video_path + f'-{i}'
|
293 |
+
omega_conf.save_caption_video_path = save_caption_video_path + f'-{i}'
|
294 |
+
omega_conf.save_audio_path = save_audio_path + f'-{i}'
|
295 |
+
omega_conf.save_audio_caption_video_path = save_audio_caption_video_path + f'-{i}'
|
296 |
+
omega_conf.seed += i
|
297 |
+
main(omega_conf)
|
298 |
+
else:
|
299 |
+
omega_conf.save_origin_video_path = save_origin_video_path
|
300 |
+
omega_conf.save_caption_video_path = save_caption_video_path
|
301 |
+
omega_conf.save_audio_path = save_audio_path
|
302 |
+
omega_conf.save_audio_caption_video_path = save_audio_caption_video_path
|
303 |
+
main(omega_conf)
|
sample_scripts/vlog_write_script.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
os.environ['CURL_CA_BUNDLE'] = ''
|
4 |
+
import argparse
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from diffusers import DiffusionPipeline
|
7 |
+
from vlogger.planning_utils.gpt4_utils import (ExtractProtagonist,
|
8 |
+
ExtractAProtagonist,
|
9 |
+
split_story,
|
10 |
+
patch_story_scripts,
|
11 |
+
refine_story_scripts,
|
12 |
+
protagonist_place_reference1,
|
13 |
+
translate_video_script,
|
14 |
+
time_scripts,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def main(args):
|
19 |
+
story_path = args.story_path
|
20 |
+
save_script_path = os.path.join(story_path.rsplit('/', 1)[0], "script")
|
21 |
+
if not os.path.exists(save_script_path):
|
22 |
+
os.makedirs(save_script_path)
|
23 |
+
with open(story_path, "r") as story_file:
|
24 |
+
story = story_file.read()
|
25 |
+
|
26 |
+
# summerize protagonists and places
|
27 |
+
protagonists_places_file_path = os.path.join(save_script_path, "protagonists_places.txt")
|
28 |
+
if args.only_one_protagonist:
|
29 |
+
character_places = ExtractAProtagonist(story, protagonists_places_file_path)
|
30 |
+
else:
|
31 |
+
character_places = ExtractProtagonist(story, protagonists_places_file_path)
|
32 |
+
print("Protagonists and places OK", flush=True)
|
33 |
+
|
34 |
+
# make script
|
35 |
+
script_file_path = os.path.join(save_script_path, "video_prompts.txt")
|
36 |
+
video_list = split_story(story, script_file_path)
|
37 |
+
video_list = patch_story_scripts(story, video_list, script_file_path)
|
38 |
+
video_list = refine_story_scripts(video_list, script_file_path)
|
39 |
+
print("Scripts OK", flush=True)
|
40 |
+
|
41 |
+
# think about the protagonist in each scene
|
42 |
+
reference_file_path = os.path.join(save_script_path, "protagonist_place_reference.txt")
|
43 |
+
reference_lists = protagonist_place_reference1(video_list, character_places, reference_file_path)
|
44 |
+
print("Reference protagonist OK", flush=True)
|
45 |
+
|
46 |
+
# translate the English script to Chinese
|
47 |
+
zh_file_path = os.path.join(save_script_path, "zh_video_prompts.txt")
|
48 |
+
zh_video_list = translate_video_script(video_list, zh_file_path)
|
49 |
+
print("Translation OK", flush=True)
|
50 |
+
|
51 |
+
# schedule the time of script
|
52 |
+
time_file_path = os.path.join(save_script_path, "time_scripts.txt")
|
53 |
+
time_list = time_scripts(video_list, time_file_path)
|
54 |
+
print("Time script OK", flush=True)
|
55 |
+
|
56 |
+
# make reference image
|
57 |
+
base = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
|
58 |
+
torch_dtype=torch.float16,
|
59 |
+
variant="fp16",
|
60 |
+
use_safetensors=True,
|
61 |
+
).to("cuda")
|
62 |
+
refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0",
|
63 |
+
text_encoder_2=base.text_encoder_2,
|
64 |
+
vae=base.vae,
|
65 |
+
torch_dtype=torch.float16,
|
66 |
+
use_safetensors=True,
|
67 |
+
variant="fp16",
|
68 |
+
).to("cuda")
|
69 |
+
ref_dir_path = os.path.join(story_path.rsplit('/', 1)[0], "ref_img")
|
70 |
+
if not os.path.exists(ref_dir_path):
|
71 |
+
os.makedirs(ref_dir_path)
|
72 |
+
for key, value in character_places.items():
|
73 |
+
prompt = key + ", " + value
|
74 |
+
img_path = os.path.join(ref_dir_path, key + ".jpg")
|
75 |
+
image = base(prompt=prompt,
|
76 |
+
output_type="latent",
|
77 |
+
height=1024,
|
78 |
+
width=1024,
|
79 |
+
guidance_scale=7
|
80 |
+
).images[0]
|
81 |
+
image = refiner(prompt=prompt, image=image[None, :]).images[0]
|
82 |
+
image.save(img_path)
|
83 |
+
print("Reference image OK", flush=True)
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
parser = argparse.ArgumentParser()
|
88 |
+
parser.add_argument("--config", type=str, default="configs/vlog_write_script.yaml")
|
89 |
+
args = parser.parse_args()
|
90 |
+
omega_conf = OmegaConf.load(args.config)
|
91 |
+
main(omega_conf)
|
sample_scripts/with_mask_ref_sample.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Sample new images from a pre-trained DiT.
|
9 |
+
"""
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
try:
|
14 |
+
import utils
|
15 |
+
from diffusion import create_diffusion
|
16 |
+
except:
|
17 |
+
# sys.path.append(os.getcwd())
|
18 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
19 |
+
# sys.path[0]
|
20 |
+
# os.path.split(sys.path[0])
|
21 |
+
import utils
|
22 |
+
|
23 |
+
from diffusion import create_diffusion
|
24 |
+
|
25 |
+
import torch
|
26 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
27 |
+
torch.backends.cudnn.allow_tf32 = True
|
28 |
+
import argparse
|
29 |
+
import torchvision
|
30 |
+
|
31 |
+
from einops import rearrange
|
32 |
+
from models import get_models
|
33 |
+
from torchvision.utils import save_image
|
34 |
+
from diffusers.models import AutoencoderKL
|
35 |
+
from models.clip import TextEmbedder
|
36 |
+
from omegaconf import OmegaConf
|
37 |
+
from PIL import Image
|
38 |
+
import numpy as np
|
39 |
+
from torchvision import transforms
|
40 |
+
sys.path.append("..")
|
41 |
+
from datasets import video_transforms
|
42 |
+
from utils import mask_generation_before
|
43 |
+
from natsort import natsorted
|
44 |
+
from diffusers.utils.import_utils import is_xformers_available
|
45 |
+
from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model
|
46 |
+
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
|
47 |
+
|
48 |
+
def get_input(args):
|
49 |
+
input_path = args.input_path
|
50 |
+
transform_video = transforms.Compose([
|
51 |
+
video_transforms.ToTensorVideo(), # TCHW
|
52 |
+
video_transforms.ResizeVideo((args.image_h, args.image_w)),
|
53 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
54 |
+
])
|
55 |
+
if input_path is not None:
|
56 |
+
print(f'loading video from {input_path}')
|
57 |
+
if os.path.isdir(input_path):
|
58 |
+
file_list = os.listdir(input_path)
|
59 |
+
video_frames = []
|
60 |
+
if args.mask_type.startswith('onelast'):
|
61 |
+
num = int(args.mask_type.split('onelast')[-1])
|
62 |
+
# get first and last frame
|
63 |
+
first_frame_path = os.path.join(input_path, natsorted(file_list)[0])
|
64 |
+
last_frame_path = os.path.join(input_path, natsorted(file_list)[-1])
|
65 |
+
first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
66 |
+
last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
67 |
+
for i in range(num):
|
68 |
+
video_frames.append(first_frame)
|
69 |
+
# add zeros to frames
|
70 |
+
num_zeros = args.num_frames-2*num
|
71 |
+
for i in range(num_zeros):
|
72 |
+
zeros = torch.zeros_like(first_frame)
|
73 |
+
video_frames.append(zeros)
|
74 |
+
for i in range(num):
|
75 |
+
video_frames.append(last_frame)
|
76 |
+
n = 0
|
77 |
+
video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
|
78 |
+
video_frames = transform_video(video_frames)
|
79 |
+
else:
|
80 |
+
for file in file_list:
|
81 |
+
if file.endswith('jpg') or file.endswith('png'):
|
82 |
+
image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0)
|
83 |
+
video_frames.append(image)
|
84 |
+
else:
|
85 |
+
continue
|
86 |
+
n = 0
|
87 |
+
video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
|
88 |
+
video_frames = transform_video(video_frames)
|
89 |
+
return video_frames, n
|
90 |
+
elif os.path.isfile(input_path):
|
91 |
+
_, full_file_name = os.path.split(input_path)
|
92 |
+
file_name, extention = os.path.splitext(full_file_name)
|
93 |
+
if extention == '.jpg' or extention == '.png':
|
94 |
+
print("loading the input image")
|
95 |
+
video_frames = []
|
96 |
+
num = int(args.mask_type.split('first')[-1])
|
97 |
+
first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
98 |
+
for i in range(num):
|
99 |
+
video_frames.append(first_frame)
|
100 |
+
num_zeros = args.num_frames-num
|
101 |
+
for i in range(num_zeros):
|
102 |
+
zeros = torch.zeros_like(first_frame)
|
103 |
+
video_frames.append(zeros)
|
104 |
+
n = 0
|
105 |
+
video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
|
106 |
+
video_frames = transform_video(video_frames)
|
107 |
+
return video_frames, n
|
108 |
+
else:
|
109 |
+
raise TypeError(f'{extention} is not supported !!')
|
110 |
+
else:
|
111 |
+
raise ValueError('Please check your path input!!')
|
112 |
+
else:
|
113 |
+
raise ValueError('Need to give a video or some images')
|
114 |
+
|
115 |
+
def auto_inpainting(args,
|
116 |
+
video_input,
|
117 |
+
masked_video,
|
118 |
+
mask,
|
119 |
+
prompt,
|
120 |
+
image,
|
121 |
+
vae,
|
122 |
+
text_encoder,
|
123 |
+
image_encoder,
|
124 |
+
diffusion,
|
125 |
+
model,
|
126 |
+
device,
|
127 |
+
):
|
128 |
+
image_prompt_embeds = None
|
129 |
+
if prompt is None:
|
130 |
+
prompt = ""
|
131 |
+
if image is not None:
|
132 |
+
clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
|
133 |
+
clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
|
134 |
+
uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
|
135 |
+
image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
|
136 |
+
image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
|
137 |
+
model = ip_scale_set(model, args.ref_cfg_scale)
|
138 |
+
if args.use_fp16:
|
139 |
+
image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
|
140 |
+
b, f, c, h, w = video_input.shape
|
141 |
+
latent_h = video_input.shape[-2] // 8
|
142 |
+
latent_w = video_input.shape[-1] // 8
|
143 |
+
|
144 |
+
if args.use_fp16:
|
145 |
+
z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
|
146 |
+
masked_video = masked_video.to(dtype=torch.float16)
|
147 |
+
mask = mask.to(dtype=torch.float16)
|
148 |
+
else:
|
149 |
+
z = torch.randn(1, 4, 16, latent_h, latent_w, device=device) # b,c,f,h,w
|
150 |
+
|
151 |
+
masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
|
152 |
+
masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
|
153 |
+
masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
|
154 |
+
mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
|
155 |
+
masked_video = torch.cat([masked_video] * 2)
|
156 |
+
mask = torch.cat([mask] * 2)
|
157 |
+
z = torch.cat([z] * 2)
|
158 |
+
prompt_all = [prompt] + [args.negative_prompt]
|
159 |
+
|
160 |
+
text_prompt = text_encoder(text_prompts=prompt_all, train=False)
|
161 |
+
model_kwargs = dict(encoder_hidden_states=text_prompt,
|
162 |
+
class_labels=None,
|
163 |
+
cfg_scale=args.cfg_scale,
|
164 |
+
use_fp16=args.use_fp16,
|
165 |
+
ip_hidden_states=image_prompt_embeds)
|
166 |
+
|
167 |
+
# Sample images:
|
168 |
+
samples = diffusion.ddim_sample_loop(
|
169 |
+
model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
|
170 |
+
mask=mask, x_start=masked_video, use_concat=True
|
171 |
+
)
|
172 |
+
samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
|
173 |
+
if args.use_fp16:
|
174 |
+
samples = samples.to(dtype=torch.float16)
|
175 |
+
|
176 |
+
video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
|
177 |
+
video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
|
178 |
+
return video_clip
|
179 |
+
|
180 |
+
def main(args):
|
181 |
+
# Setup PyTorch:
|
182 |
+
if args.seed:
|
183 |
+
torch.manual_seed(args.seed)
|
184 |
+
torch.set_grad_enabled(False)
|
185 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
186 |
+
# device = "cpu"
|
187 |
+
|
188 |
+
if args.ckpt is None:
|
189 |
+
raise ValueError("Please specify a checkpoint path using --ckpt <path>")
|
190 |
+
|
191 |
+
# Load model:
|
192 |
+
latent_h = args.image_size[0] // 8
|
193 |
+
latent_w = args.image_size[1] // 8
|
194 |
+
args.image_h = args.image_size[0]
|
195 |
+
args.image_w = args.image_size[1]
|
196 |
+
args.latent_h = latent_h
|
197 |
+
args.latent_w = latent_w
|
198 |
+
print('loading model')
|
199 |
+
model = get_models(args).to(device)
|
200 |
+
model = tca_transform_model(model).to(device)
|
201 |
+
model = ip_transform_model(model).to(device)
|
202 |
+
|
203 |
+
if args.enable_xformers_memory_efficient_attention:
|
204 |
+
if is_xformers_available():
|
205 |
+
model.enable_xformers_memory_efficient_attention()
|
206 |
+
else:
|
207 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
208 |
+
|
209 |
+
# load model
|
210 |
+
ckpt_path = args.ckpt
|
211 |
+
state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
|
212 |
+
model_dict = model.state_dict()
|
213 |
+
pretrained_dict = {}
|
214 |
+
for k, v in state_dict.items():
|
215 |
+
if k in model_dict:
|
216 |
+
pretrained_dict[k] = v
|
217 |
+
model_dict.update(pretrained_dict)
|
218 |
+
model.load_state_dict(model_dict)
|
219 |
+
|
220 |
+
model.eval()
|
221 |
+
pretrained_model_path = args.pretrained_model_path
|
222 |
+
diffusion = create_diffusion(str(args.num_sampling_steps))
|
223 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
|
224 |
+
text_encoder = TextEmbedder(pretrained_model_path).to(device)
|
225 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
|
226 |
+
if args.use_fp16:
|
227 |
+
print('Warnning: using half percision for inferencing!')
|
228 |
+
vae.to(dtype=torch.float16)
|
229 |
+
model.to(dtype=torch.float16)
|
230 |
+
text_encoder.to(dtype=torch.float16)
|
231 |
+
|
232 |
+
# prompt:
|
233 |
+
prompt = args.text_prompt
|
234 |
+
if prompt ==[]:
|
235 |
+
prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
|
236 |
+
else:
|
237 |
+
prompt = prompt[0]
|
238 |
+
prompt_base = prompt.replace(' ','_')
|
239 |
+
prompt = prompt + args.additional_prompt
|
240 |
+
|
241 |
+
if not os.path.exists(os.path.join(args.save_path)):
|
242 |
+
os.makedirs(os.path.join(args.save_path))
|
243 |
+
video_input, researve_frames = get_input(args) # f,c,h,w
|
244 |
+
video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w
|
245 |
+
mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w
|
246 |
+
masked_video = video_input * (mask == 0)
|
247 |
+
|
248 |
+
pil_image = Image.open(args.ref_path)
|
249 |
+
pil_image.resize((256, 256))
|
250 |
+
|
251 |
+
video_clip = auto_inpainting(args,
|
252 |
+
video_input,
|
253 |
+
masked_video,
|
254 |
+
mask,
|
255 |
+
prompt,
|
256 |
+
pil_image,
|
257 |
+
vae,
|
258 |
+
text_encoder,
|
259 |
+
image_encoder,
|
260 |
+
diffusion,
|
261 |
+
model,
|
262 |
+
device,
|
263 |
+
)
|
264 |
+
video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
|
265 |
+
save_video_path = os.path.join(args.save_path, prompt_base+ '.mp4')
|
266 |
+
torchvision.io.write_video(save_video_path, video_, fps=8)
|
267 |
+
print(f'save in {save_video_path}')
|
268 |
+
|
269 |
+
|
270 |
+
if __name__ == "__main__":
|
271 |
+
parser = argparse.ArgumentParser()
|
272 |
+
parser.add_argument("--config", type=str, default="configs/with_mask_ref_sample.yaml")
|
273 |
+
args = parser.parse_args()
|
274 |
+
omega_conf = OmegaConf.load(args.config)
|
275 |
+
main(omega_conf)
|
sample_scripts/with_mask_sample.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Sample new images from a pre-trained DiT.
|
9 |
+
"""
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import math
|
13 |
+
try:
|
14 |
+
import utils
|
15 |
+
from diffusion import create_diffusion
|
16 |
+
except:
|
17 |
+
# sys.path.append(os.getcwd())
|
18 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
19 |
+
# sys.path[0]
|
20 |
+
# os.path.split(sys.path[0])
|
21 |
+
import utils
|
22 |
+
|
23 |
+
from diffusion import create_diffusion
|
24 |
+
|
25 |
+
import torch
|
26 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
27 |
+
torch.backends.cudnn.allow_tf32 = True
|
28 |
+
import argparse
|
29 |
+
import torchvision
|
30 |
+
|
31 |
+
from einops import rearrange
|
32 |
+
from models import get_models
|
33 |
+
from torchvision.utils import save_image
|
34 |
+
from diffusers.models import AutoencoderKL
|
35 |
+
from models.clip import TextEmbedder
|
36 |
+
from omegaconf import OmegaConf
|
37 |
+
from PIL import Image
|
38 |
+
import numpy as np
|
39 |
+
from torchvision import transforms
|
40 |
+
sys.path.append("..")
|
41 |
+
from datasets import video_transforms
|
42 |
+
from utils import mask_generation_before
|
43 |
+
from natsort import natsorted
|
44 |
+
from diffusers.utils.import_utils import is_xformers_available
|
45 |
+
from vlogger.STEB.model_transform import tca_transform_model
|
46 |
+
|
47 |
+
def get_input(args):
|
48 |
+
input_path = args.input_path
|
49 |
+
transform_video = transforms.Compose([
|
50 |
+
video_transforms.ToTensorVideo(), # TCHW
|
51 |
+
video_transforms.ResizeVideo((args.image_h, args.image_w)),
|
52 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
53 |
+
])
|
54 |
+
if input_path is not None:
|
55 |
+
print(f'loading video from {input_path}')
|
56 |
+
if os.path.isdir(input_path):
|
57 |
+
file_list = os.listdir(input_path)
|
58 |
+
video_frames = []
|
59 |
+
if args.mask_type.startswith('onelast'):
|
60 |
+
num = int(args.mask_type.split('onelast')[-1])
|
61 |
+
# get first and last frame
|
62 |
+
first_frame_path = os.path.join(input_path, natsorted(file_list)[0])
|
63 |
+
last_frame_path = os.path.join(input_path, natsorted(file_list)[-1])
|
64 |
+
first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
65 |
+
last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
66 |
+
for i in range(num):
|
67 |
+
video_frames.append(first_frame)
|
68 |
+
# add zeros to frames
|
69 |
+
num_zeros = args.num_frames-2*num
|
70 |
+
for i in range(num_zeros):
|
71 |
+
zeros = torch.zeros_like(first_frame)
|
72 |
+
video_frames.append(zeros)
|
73 |
+
for i in range(num):
|
74 |
+
video_frames.append(last_frame)
|
75 |
+
n = 0
|
76 |
+
video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
|
77 |
+
video_frames = transform_video(video_frames)
|
78 |
+
else:
|
79 |
+
for file in file_list:
|
80 |
+
if file.endswith('jpg') or file.endswith('png'):
|
81 |
+
image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0)
|
82 |
+
video_frames.append(image)
|
83 |
+
else:
|
84 |
+
continue
|
85 |
+
n = 0
|
86 |
+
video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
|
87 |
+
video_frames = transform_video(video_frames)
|
88 |
+
return video_frames, n
|
89 |
+
elif os.path.isfile(input_path):
|
90 |
+
_, full_file_name = os.path.split(input_path)
|
91 |
+
file_name, extention = os.path.splitext(full_file_name)
|
92 |
+
if extention == '.jpg' or extention == '.png':
|
93 |
+
print("loading the input image")
|
94 |
+
video_frames = []
|
95 |
+
num = int(args.mask_type.split('first')[-1])
|
96 |
+
first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
97 |
+
for i in range(num):
|
98 |
+
video_frames.append(first_frame)
|
99 |
+
num_zeros = args.num_frames-num
|
100 |
+
for i in range(num_zeros):
|
101 |
+
zeros = torch.zeros_like(first_frame)
|
102 |
+
video_frames.append(zeros)
|
103 |
+
n = 0
|
104 |
+
video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
|
105 |
+
video_frames = transform_video(video_frames)
|
106 |
+
return video_frames, n
|
107 |
+
else:
|
108 |
+
raise TypeError(f'{extention} is not supported !!')
|
109 |
+
else:
|
110 |
+
raise ValueError('Please check your path input!!')
|
111 |
+
else:
|
112 |
+
raise ValueError('Need to give a video or some images')
|
113 |
+
|
114 |
+
def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,):
|
115 |
+
b,f,c,h,w=video_input.shape
|
116 |
+
latent_h = args.image_size[0] // 8
|
117 |
+
latent_w = args.image_size[1] // 8
|
118 |
+
|
119 |
+
# prepare inputs
|
120 |
+
if args.use_fp16:
|
121 |
+
z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
|
122 |
+
masked_video = masked_video.to(dtype=torch.float16)
|
123 |
+
mask = mask.to(dtype=torch.float16)
|
124 |
+
else:
|
125 |
+
z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
|
126 |
+
|
127 |
+
|
128 |
+
masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
|
129 |
+
masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
|
130 |
+
masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
|
131 |
+
mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
|
132 |
+
|
133 |
+
# classifier_free_guidance
|
134 |
+
if args.do_classifier_free_guidance:
|
135 |
+
masked_video = torch.cat([masked_video] * 2)
|
136 |
+
mask = torch.cat([mask] * 2)
|
137 |
+
z = torch.cat([z] * 2)
|
138 |
+
prompt_all = [prompt] + [args.negative_prompt]
|
139 |
+
|
140 |
+
else:
|
141 |
+
masked_video = masked_video
|
142 |
+
mask = mask
|
143 |
+
z = z
|
144 |
+
prompt_all = [prompt]
|
145 |
+
|
146 |
+
text_prompt = text_encoder(text_prompts=prompt_all, train=False)
|
147 |
+
model_kwargs = dict(encoder_hidden_states=text_prompt,
|
148 |
+
class_labels=None,
|
149 |
+
cfg_scale=args.cfg_scale,
|
150 |
+
use_fp16=args.use_fp16,) # tav unet
|
151 |
+
|
152 |
+
# Sample video:
|
153 |
+
if args.sample_method == 'ddim':
|
154 |
+
samples = diffusion.ddim_sample_loop(
|
155 |
+
model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
|
156 |
+
mask=mask, x_start=masked_video, use_concat=args.use_mask
|
157 |
+
)
|
158 |
+
elif args.sample_method == 'ddpm':
|
159 |
+
samples = diffusion.p_sample_loop(
|
160 |
+
model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
|
161 |
+
mask=mask, x_start=masked_video, use_concat=args.use_mask
|
162 |
+
)
|
163 |
+
samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
|
164 |
+
if args.use_fp16:
|
165 |
+
samples = samples.to(dtype=torch.float16)
|
166 |
+
|
167 |
+
video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
|
168 |
+
video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
|
169 |
+
return video_clip
|
170 |
+
|
171 |
+
def main(args):
|
172 |
+
# Setup PyTorch:
|
173 |
+
if args.seed:
|
174 |
+
torch.manual_seed(args.seed)
|
175 |
+
torch.set_grad_enabled(False)
|
176 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
177 |
+
# device = "cpu"
|
178 |
+
|
179 |
+
if args.ckpt is None:
|
180 |
+
raise ValueError("Please specify a checkpoint path using --ckpt <path>")
|
181 |
+
|
182 |
+
# Load model:
|
183 |
+
latent_h = args.image_size[0] // 8
|
184 |
+
latent_w = args.image_size[1] // 8
|
185 |
+
args.image_h = args.image_size[0]
|
186 |
+
args.image_w = args.image_size[1]
|
187 |
+
args.latent_h = latent_h
|
188 |
+
args.latent_w = latent_w
|
189 |
+
print('loading model')
|
190 |
+
model = get_models(args).to(device)
|
191 |
+
model = tca_transform_model(model).to(device)
|
192 |
+
|
193 |
+
if args.enable_xformers_memory_efficient_attention:
|
194 |
+
if is_xformers_available():
|
195 |
+
model.enable_xformers_memory_efficient_attention()
|
196 |
+
else:
|
197 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
198 |
+
|
199 |
+
# load model
|
200 |
+
ckpt_path = args.ckpt
|
201 |
+
state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
|
202 |
+
model_dict = model.state_dict()
|
203 |
+
pretrained_dict = {}
|
204 |
+
for k, v in state_dict.items():
|
205 |
+
if k in model_dict:
|
206 |
+
pretrained_dict[k] = v
|
207 |
+
model_dict.update(pretrained_dict)
|
208 |
+
model.load_state_dict(model_dict)
|
209 |
+
|
210 |
+
model.eval()
|
211 |
+
pretrained_model_path = args.pretrained_model_path
|
212 |
+
diffusion = create_diffusion(str(args.num_sampling_steps))
|
213 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
|
214 |
+
text_encoder = TextEmbedder(pretrained_model_path).to(device)
|
215 |
+
if args.use_fp16:
|
216 |
+
print('Warnning: using half percision for inferencing!')
|
217 |
+
vae.to(dtype=torch.float16)
|
218 |
+
model.to(dtype=torch.float16)
|
219 |
+
text_encoder.to(dtype=torch.float16)
|
220 |
+
|
221 |
+
# prompt:
|
222 |
+
prompt = args.text_prompt
|
223 |
+
if prompt ==[]:
|
224 |
+
prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
|
225 |
+
else:
|
226 |
+
prompt = prompt[0]
|
227 |
+
prompt_base = prompt.replace(' ','_')
|
228 |
+
prompt = prompt + args.additional_prompt
|
229 |
+
|
230 |
+
if not os.path.exists(os.path.join(args.save_path)):
|
231 |
+
os.makedirs(os.path.join(args.save_path))
|
232 |
+
video_input, researve_frames = get_input(args) # f,c,h,w
|
233 |
+
video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w
|
234 |
+
mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w
|
235 |
+
masked_video = video_input * (mask == 0)
|
236 |
+
|
237 |
+
video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
|
238 |
+
video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
|
239 |
+
save_video_path = os.path.join(args.save_path, prompt_base+ '.mp4')
|
240 |
+
torchvision.io.write_video(save_video_path, video_, fps=8)
|
241 |
+
print(f'save in {save_video_path}')
|
242 |
+
|
243 |
+
|
244 |
+
if __name__ == "__main__":
|
245 |
+
parser = argparse.ArgumentParser()
|
246 |
+
parser.add_argument("--config", type=str, default="configs/with_mask_sample.yaml")
|
247 |
+
args = parser.parse_args()
|
248 |
+
omega_conf = OmegaConf.load(args.config)
|
249 |
+
main(omega_conf)
|