Migrated from GitHub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- LICENSE +202 -0
- ORIGINAL_README.md +171 -0
- assets/DiffuEraser_pipeline.png +0 -0
- diffueraser/diffueraser.py +432 -0
- diffueraser/pipeline_diffueraser.py +1349 -0
- examples/example1/mask.mp4 +0 -0
- examples/example1/video.mp4 +0 -0
- examples/example2/mask.mp4 +3 -0
- examples/example2/video.mp4 +0 -0
- examples/example3/mask.mp4 +0 -0
- examples/example3/video.mp4 +3 -0
- libs/brushnet_CA.py +939 -0
- libs/transformer_temporal.py +375 -0
- libs/unet_2d_blocks.py +0 -0
- libs/unet_2d_condition.py +1359 -0
- libs/unet_3d_blocks.py +2463 -0
- libs/unet_motion_model.py +975 -0
- propainter/RAFT/__init__.py +2 -0
- propainter/RAFT/corr.py +111 -0
- propainter/RAFT/datasets.py +235 -0
- propainter/RAFT/demo.py +79 -0
- propainter/RAFT/extractor.py +267 -0
- propainter/RAFT/raft.py +146 -0
- propainter/RAFT/update.py +139 -0
- propainter/RAFT/utils/__init__.py +2 -0
- propainter/RAFT/utils/augmentor.py +246 -0
- propainter/RAFT/utils/flow_viz.py +132 -0
- propainter/RAFT/utils/flow_viz_pt.py +118 -0
- propainter/RAFT/utils/frame_utils.py +137 -0
- propainter/RAFT/utils/utils.py +82 -0
- propainter/core/dataset.py +232 -0
- propainter/core/dist.py +47 -0
- propainter/core/loss.py +180 -0
- propainter/core/lr_scheduler.py +112 -0
- propainter/core/metrics.py +571 -0
- propainter/core/prefetch_dataloader.py +125 -0
- propainter/core/trainer.py +509 -0
- propainter/core/trainer_flow_w_edge.py +380 -0
- propainter/core/utils.py +371 -0
- propainter/inference.py +520 -0
- propainter/model/__init__.py +1 -0
- propainter/model/canny/canny_filter.py +256 -0
- propainter/model/canny/filter.py +288 -0
- propainter/model/canny/gaussian.py +116 -0
- propainter/model/canny/kernels.py +690 -0
- propainter/model/canny/sobel.py +263 -0
- propainter/model/misc.py +131 -0
- propainter/model/modules/base_module.py +131 -0
- propainter/model/modules/deformconv.py +54 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/example2/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/example3/video.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

(Standard, unmodified Apache License 2.0 text: Sections 1-9, the "END OF TERMS AND CONDITIONS" marker, and the appendix describing how to apply the license with the "Copyright [yyyy] [name of copyright owner]" boilerplate notice. The full text is available at http://www.apache.org/licenses/LICENSE-2.0.)
ORIGINAL_README.md
ADDED
@@ -0,0 +1,171 @@
<div align="center">

<h1>DiffuEraser: A Diffusion Model for Video Inpainting</h1>

<div>
    Xiaowen Li&emsp;
    Haolan Xue&emsp;
    Peiran Ren&emsp;
    Liefeng Bo
</div>
<div>
    Tongyi Lab, Alibaba Group&emsp;
</div>

<div>
    <strong>TECHNICAL REPORT</strong>
</div>

<div>
    <h4 align="center">
        <a href="https://lixiaowen-xw.github.io/DiffuEraser-page" target='_blank'>
        <img src="https://img.shields.io/badge/%F0%9F%8C%B1-Project%20Page-blue">
        </a>
        <a href="https://arxiv.org/abs/2501.10018" target='_blank'>
        <img src="https://img.shields.io/badge/arXiv-2501.10018-B31B1B.svg">
        </a>
    </h4>
</div>

</div>

DiffuEraser is a diffusion model for video inpainting. It outperforms the state-of-the-art model ProPainter in both content completeness and temporal consistency while maintaining acceptable efficiency.

---


## Update Log
- *2025.01.20*: Release inference code.


## TODO
- [ ] Release training code.
- [ ] Release HuggingFace/ModelScope demo.
- [ ] Release gradio demo.


## Results
More results will be displayed on the project page.

https://github.com/user-attachments/assets/b59d0b88-4186-4531-8698-adf6e62058f8


## Method Overview
Our network is inspired by [BrushNet](https://github.com/TencentARC/BrushNet) and [AnimateDiff](https://github.com/guoyww/AnimateDiff). The architecture comprises the primary `denoising UNet` and an auxiliary `BrushNet branch`. Features extracted by the BrushNet branch are integrated into the denoising UNet layer by layer after a zero-convolution block. The denoising UNet performs the denoising process to generate the final output. To enhance temporal consistency, `temporal attention` mechanisms are incorporated after both the self-attention and cross-attention layers. After denoising, the generated images are blended with the input masked images using blurred masks (a small sketch of this blending step is given at the end of this section).

![overall_structure](assets/DiffuEraser_pipeline.png)

We incorporate `prior` information to provide initialization and weak conditioning, which helps mitigate noisy artifacts and suppress hallucinations.
Additionally, to improve temporal consistency during long-sequence inference, we expand the `temporal receptive fields` of both the prior model and DiffuEraser, and further enhance consistency by leveraging the temporal smoothing capabilities of Video Diffusion Models. Please read the paper for details.
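The blending mentioned above corresponds to the compose stage in `diffueraser/diffueraser.py`. Below is a minimal single-frame sketch of the idea, assuming `frame`, `generated`, and `mask` are same-sized PIL images (the helper name and variables are illustrative, not part of the repository's API):

```python
import cv2
import numpy as np
from PIL import Image

def blend_with_blurred_mask(frame: Image.Image, generated: Image.Image, mask: Image.Image) -> Image.Image:
    """Paste the generated content back into the original frame through a feathered mask."""
    m = np.array(mask.convert("L")).astype(np.float32) / 255.0   # 1 inside the inpainting region
    blurred = cv2.GaussianBlur(m, (21, 21), 0)                   # soften the mask boundary
    m = 1.0 - (1.0 - m) * (1.0 - blurred)                        # keep the hole fully masked, feather only the border
    m3 = np.repeat(m[:, :, None], 3, axis=2)
    out = np.array(generated).astype(np.float32) * m3 + np.array(frame).astype(np.float32) * (1.0 - m3)
    return Image.fromarray(out.astype(np.uint8))
```

The `(21, 21)` Gaussian kernel matches the value used in the released inference code and can be tuned for sharper or softer seams.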


## Getting Started

#### Installation

1. Clone Repo

```bash
git clone https://github.com/lixiaowen-xw/DiffuEraser.git
```

2. Create Conda Environment and Install Dependencies

```bash
# create new anaconda env
conda create -n diffueraser python=3.9.19
conda activate diffueraser
# install python dependencies
pip install -r requirements.txt
```

#### Prepare pretrained models
Weights will be placed under the `./weights` directory.
1. Download our pretrained models from [Hugging Face](https://huggingface.co/lixiaowen/diffuEraser) or [ModelScope](https://www.modelscope.cn/xingzi/diffuEraser.git) to the `weights` folder (a download sketch is given below the directory layout).
2. Download the pretrained weights of the base model and other components:
- [stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). The full folder is over 30 GB; to save storage space, you can download only the necessary files (feature_extractor, model_index.json, safety_checker, scheduler, text_encoder, and tokenizer), about 4 GB.
- [PCM_Weights](https://huggingface.co/wangfuyun/PCM_Weights)
- [propainter](https://github.com/sczhou/ProPainter/releases/tag/v0.1.0)
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)


The directory structure will be arranged as:
```
weights
   |- diffuEraser
      |-brushnet
      |-unet_main
   |- stable-diffusion-v1-5
      |-feature_extractor
      |-...
   |- PCM_Weights
      |-sd15
   |- propainter
      |-ProPainter.pth
      |-raft-things.pth
      |-recurrent_flow_completion.pth
   |- sd-vae-ft-mse
      |-diffusion_pytorch_model.bin
      |-...
   |- README.md
```
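
If you prefer to script the downloads, a minimal sketch using `huggingface_hub` is shown below. The repo ids come from the links above; the partial-download pattern for stable-diffusion-v1-5 is an assumption and may need adjusting:

```python
from huggingface_hub import snapshot_download

# DiffuEraser weights (brushnet + unet_main)
snapshot_download("lixiaowen/diffuEraser", local_dir="weights/diffuEraser")

# Only the stable-diffusion-v1-5 parts listed above (~4 GB instead of 30+ GB)
snapshot_download(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    local_dir="weights/stable-diffusion-v1-5",
    allow_patterns=["feature_extractor/*", "model_index.json", "safety_checker/*",
                    "scheduler/*", "text_encoder/*", "tokenizer/*"],
)

snapshot_download("wangfuyun/PCM_Weights", local_dir="weights/PCM_Weights")
snapshot_download("stabilityai/sd-vae-ft-mse", local_dir="weights/sd-vae-ft-mse")
```

The ProPainter checkpoints (`ProPainter.pth`, `raft-things.pth`, `recurrent_flow_completion.pth`) still need to be fetched from the GitHub release linked above.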

#### Main Inference
We provide some examples in the [`examples`](./examples) folder.
Run the following commands to try it out:
```shell
cd DiffuEraser
python run_diffueraser.py
```
The results will be saved in the `results` folder.
To test your own videos, replace `input_video` and `input_mask` in run_diffueraser.py. The first inference may take a long time.
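
If you want to call the model from your own script instead of `run_diffueraser.py`, a rough sketch of the programmatic interface follows, based on the `DiffuEraser` class in `diffueraser/diffueraser.py` from this commit. The weight paths and the ProPainter `priori` video are assumptions about how `run_diffueraser.py` wires things together:

```python
import torch
from diffueraser.diffueraser import DiffuEraser

device = "cuda" if torch.cuda.is_available() else "cpu"
model = DiffuEraser(
    device,
    base_model_path="weights/stable-diffusion-v1-5",
    vae_path="weights/sd-vae-ft-mse",
    diffueraser_path="weights/diffuEraser",
)

# `priori` is a video produced by the ProPainter prior model for the same clip.
# Note that DiffuEraser.forward removes the priori file after reading it, and the
# clip must yield at least 22 frames.
output = model.forward(
    validation_image="examples/example1/video.mp4",
    validation_mask="examples/example1/mask.mp4",
    priori="results/priori.mp4",          # hypothetical path
    output_path="results/example1_out.mp4",
    video_length=2,                        # seconds of video to process
    nframes=22,
    seed=42,
)
print("saved to", output)
```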

The `frame rate` of `input_video` and `input_mask` needs to be consistent. We currently only support `mp4 video` as input instead of split frames; you can convert frames to a video using ffmpeg:
```shell
ffmpeg -i image%03d.jpg -c:v libx264 -r 25 output.mp4
```
Notice: if the mask video's frame rate is not consistent with that of the input video, do not fix it by converting the mask's frame rate; resampling leads to errors due to frame misalignment. A quick way to check both frame rates is shown below.

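A minimal check using the same OpenCV call the repository's loaders rely on (the paths are placeholders):

```python
import cv2

def get_fps(path: str) -> float:
    cap = cv2.VideoCapture(path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps

video_fps = get_fps("examples/example1/video.mp4")
mask_fps = get_fps("examples/example1/mask.mp4")
print(video_fps, mask_fps)
assert video_fps == mask_fps, "frame rates differ; re-encode the mask from frames at the video's frame rate"
```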

Below are the estimated GPU memory requirements and inference times for different resolutions:

| Resolution | GPU Memory | Inference Time (250 frames, ~10 s, L20) |
| :--------- | :--------: | :-------------------------------------: |
| 1280 x 720 | 33 GB | 314 s |
| 960 x 540 | 20 GB | 175 s |
| 640 x 360 | 12 GB | 92 s |


## Citation

If you find our repo useful for your research, please consider citing our paper:

```bibtex
@misc{li2025diffueraserdiffusionmodelvideo,
      title={DiffuEraser: A Diffusion Model for Video Inpainting},
      author={Xiaowen Li and Haolan Xue and Peiran Ren and Liefeng Bo},
      year={2025},
      eprint={2501.10018},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2501.10018},
}
```


## License
This repository uses [ProPainter](https://github.com/sczhou/ProPainter) as the prior model. Users must comply with [ProPainter's license](https://github.com/sczhou/ProPainter/blob/main/LICENSE) when using this code, or you can replace it with another prior model.

This project is licensed under the [Apache License Version 2.0](./LICENSE) except for the third-party components listed below.


## Acknowledgement

This code is based on [BrushNet](https://github.com/TencentARC/BrushNet), [ProPainter](https://github.com/sczhou/ProPainter) and [AnimateDiff](https://github.com/guoyww/AnimateDiff). The example videos come from [Pexels](https://www.pexels.com/), [DAVIS](https://davischallenge.org/), [SA-V](https://ai.meta.com/datasets/segment-anything-video) and [DanceTrack](https://dancetrack.github.io/). Thanks for their awesome work.
assets/DiffuEraser_pipeline.png
ADDED
diffueraser/diffueraser.py
ADDED
@@ -0,0 +1,432 @@
import gc
import copy
import cv2
import os
import numpy as np
import torch
import torchvision
from einops import repeat
from PIL import Image, ImageFilter
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UniPCMultistepScheduler,
    LCMScheduler,
)
from diffusers.schedulers import TCDScheduler
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoTokenizer, PretrainedConfig

from libs.unet_motion_model import MotionAdapter, UNetMotionModel
from libs.brushnet_CA import BrushNetModel
from libs.unet_2d_condition import UNet2DConditionModel
from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline


checkpoints = {
    "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
    "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
    "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
    "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
    "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
    "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
    "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
    "LCM-Like LoRA": [
        "pcm_{}_lcmlike_lora_converted.safetensors",
        4,
        0.0,
    ],
}

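# Illustrative note (not part of the original file): each `checkpoints` entry is
# [PCM LoRA filename template, number of inference steps, guidance scale]. For example,
# assuming mode="sd15" and ckpt="Normal CFG 4-Step", DiffuEraser.__init__ below loads
#   checkpoints["Normal CFG 4-Step"][0].format("sd15")
#   == "pcm_sd15_normalcfg_4step_converted.safetensors"
# from weights/PCM_Weights/sd15 and uses checkpoints[ckpt][1] == 4 inference steps.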
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")

def resize_frames(frames, size=None):
    if size is not None:
        out_size = size
        process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
        frames = [f.resize(process_size) for f in frames]
    else:
        out_size = frames[0].size
        process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
        if not out_size == process_size:
            frames = [f.resize(process_size) for f in frames]

    return frames

def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilation_iter, frames):
    cap = cv2.VideoCapture(validation_mask)
    if not cap.isOpened():
        print("Error: Could not open mask video.")
        exit()
    mask_fps = cap.get(cv2.CAP_PROP_FPS)
    if mask_fps != fps:
        cap.release()
        raise ValueError("The frame rate of all input videos needs to be consistent.")

    masks = []
    masked_images = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if(idx >= n_total_frames):
            break
        mask = Image.fromarray(frame[...,::-1]).convert('L')
        if mask.size != img_size:
            mask = mask.resize(img_size, Image.NEAREST)
        mask = np.asarray(mask)
        m = np.array(mask > 0).astype(np.uint8)
        m = cv2.erode(m,
                      cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                      iterations=1)
        m = cv2.dilate(m,
                       cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                       iterations=mask_dilation_iter)

        mask = Image.fromarray(m * 255)
        masks.append(mask)

        masked_image = np.array(frames[idx])*(1-(np.array(mask)[:,:,np.newaxis].astype(np.float32)/255))
        masked_image = Image.fromarray(masked_image.astype(np.uint8))
        masked_images.append(masked_image)

        idx += 1
    cap.release()

    return masks, masked_images

def read_priori(priori, fps, n_total_frames, img_size):
    cap = cv2.VideoCapture(priori)
    if not cap.isOpened():
        print("Error: Could not open video.")
        exit()
    priori_fps = cap.get(cv2.CAP_PROP_FPS)
    if priori_fps != fps:
        cap.release()
        raise ValueError("The frame rate of all input videos needs to be consistent.")

    prioris=[]
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if(idx >= n_total_frames):
            break
        img = Image.fromarray(frame[...,::-1])
        if img.size != img_size:
            img = img.resize(img_size)
        prioris.append(img)
        idx += 1
    cap.release()

    os.remove(priori) # remove priori

    return prioris

def read_video(validation_image, video_length, nframes, max_img_size):
    vframes, aframes, info = torchvision.io.read_video(filename=validation_image, pts_unit='sec', end_pts=video_length) # RGB
    fps = info['video_fps']
    n_total_frames = int(video_length * fps)
    n_clip = int(np.ceil(n_total_frames/nframes))

    frames = list(vframes.numpy())[:n_total_frames]
    frames = [Image.fromarray(f) for f in frames]
    max_size = max(frames[0].size)
    if(max_size<256):
        raise ValueError("The resolution of the uploaded video must be larger than 256x256.")
    if(max_size>4096):
        raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.")
    if max_size>max_img_size:
        ratio = max_size/max_img_size
        ratio_size = (int(frames[0].size[0]/ratio),int(frames[0].size[1]/ratio))
        img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8)
        resize_flag=True
    elif (frames[0].size[0]%8==0) and (frames[0].size[1]%8==0):
        img_size = frames[0].size
        resize_flag=False
    else:
        ratio_size = frames[0].size
        img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8)
        resize_flag=True
    if resize_flag:
        frames = resize_frames(frames, img_size)
        img_size = frames[0].size

    return frames, fps, img_size, n_clip, n_total_frames


class DiffuEraser:
    def __init__(
            self, device, base_model_path, vae_path, diffueraser_path, revision=None,
            ckpt="Normal CFG 4-Step", mode="sd15", loaded=None):
        self.device = device

        ## load model
        self.vae = AutoencoderKL.from_pretrained(vae_path)
        self.noise_scheduler = DDPMScheduler.from_pretrained(base_model_path,
                                subfolder="scheduler",
                                prediction_type="v_prediction",
                                timestep_spacing="trailing",
                                rescale_betas_zero_snr=True
                            )
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_path,
            subfolder="tokenizer",
            use_fast=False,
        )
        text_encoder_cls = import_model_class_from_model_name_or_path(base_model_path,revision)
        self.text_encoder = text_encoder_cls.from_pretrained(
            base_model_path, subfolder="text_encoder"
        )
        self.brushnet = BrushNetModel.from_pretrained(diffueraser_path, subfolder="brushnet")
        self.unet_main = UNetMotionModel.from_pretrained(
            diffueraser_path, subfolder="unet_main",
        )

        ## set pipeline
        self.pipeline = StableDiffusionDiffuEraserPipeline.from_pretrained(
            base_model_path,
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet_main,
            brushnet=self.brushnet
        ).to(self.device, torch.float16)
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.pipeline.set_progress_bar_config(disable=True)

        self.noise_scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)

        ## use PCM
        self.ckpt = ckpt
        PCM_ckpts = checkpoints[ckpt][0].format(mode)
        self.guidance_scale = checkpoints[ckpt][2]
        if loaded != (ckpt + mode):
            self.pipeline.load_lora_weights(
                "weights/PCM_Weights", weight_name=PCM_ckpts, subfolder=mode
            )
            loaded = ckpt + mode

        if ckpt == "LCM-Like LoRA":
            self.pipeline.scheduler = LCMScheduler()
        else:
            self.pipeline.scheduler = TCDScheduler(
                num_train_timesteps=1000,
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                timestep_spacing="trailing",
            )
        self.num_inference_steps = checkpoints[ckpt][1]
        self.guidance_scale = 0

    def forward(self, validation_image, validation_mask, priori, output_path,
                max_img_size = 1280, video_length=2, mask_dilation_iter=4,
                nframes=22, seed=None, revision = None, guidance_scale=None, blended=True):
        validation_prompt = "" #
        guidance_scale_final = self.guidance_scale if guidance_scale==None else guidance_scale

        if (max_img_size<256 or max_img_size>1920):
            raise ValueError("The max_img_size must be larger than 256, smaller than 1920.")

        ################ read input video ################
        frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size)
        video_len = len(frames)

        ################ read mask ################
        validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames)

        ################ read priori ################
        prioris = read_priori(priori, fps, n_total_frames, img_size)

        ## recheck
        n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris))
        if(n_total_frames<22):
            raise ValueError("The effective video duration is too short. Please make sure that the number of frames of video, mask, and priori is at least greater than 22 frames.")
        validation_masks_input = validation_masks_input[:n_total_frames]
        validation_images_input = validation_images_input[:n_total_frames]
        frames = frames[:n_total_frames]
        prioris = prioris[:n_total_frames]

        prioris = resize_frames(prioris)
        validation_masks_input = resize_frames(validation_masks_input)
        validation_images_input = resize_frames(validation_images_input)
        resized_frames = resize_frames(frames)

        ##############################################
        # DiffuEraser inference
        ##############################################
        print("DiffuEraser inference...")
        if seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        ## random noise
        real_video_length = len(validation_images_input)
        tar_width, tar_height = validation_images_input[0].size
        shape = (
            nframes,
            4,
            tar_height//8,
            tar_width//8
        )
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet_main is not None:
            prompt_embeds_dtype = self.unet_main.dtype
        else:
            prompt_embeds_dtype = torch.float16
        noise_pre = randn_tensor(shape, device=torch.device(self.device), dtype=prompt_embeds_dtype, generator=generator)
        noise = repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length,...]

        ################ prepare priori ################
        images_preprocessed = []
        for image in prioris:
            image = self.image_processor.preprocess(image, height=tar_height, width=tar_width).to(dtype=torch.float32)
            image = image.to(device=torch.device(self.device), dtype=torch.float16)
            images_preprocessed.append(image)
        pixel_values = torch.cat(images_preprocessed)

        with torch.no_grad():
            pixel_values = pixel_values.to(dtype=torch.float16)
            latents = []
            num=4
            for i in range(0, pixel_values.shape[0], num):
                latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
            latents = torch.cat(latents, dim=0)
            latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
        torch.cuda.empty_cache()
        timesteps = torch.tensor([0], device=self.device)
        timesteps = timesteps.long()

        validation_masks_input_ori = copy.deepcopy(validation_masks_input)
        resized_frames_ori = copy.deepcopy(resized_frames)
        ################ Pre-inference ################
        if n_total_frames > nframes*2: ## do pre-inference only when number of input frames is larger than nframes*2
            ## sample
            step = n_total_frames / nframes
            sample_index = [int(i * step) for i in range(nframes)]
            sample_index = sample_index[:22]
            validation_masks_input_pre = [validation_masks_input[i] for i in sample_index]
            validation_images_input_pre = [validation_images_input[i] for i in sample_index]
            latents_pre = torch.stack([latents[i] for i in sample_index])

            ## add proiri
            noisy_latents_pre = self.noise_scheduler.add_noise(latents_pre, noise_pre, timesteps)
            latents_pre = noisy_latents_pre

            with torch.no_grad():
                latents_pre_out = self.pipeline(
                    num_frames=nframes,
                    prompt=validation_prompt,
                    images=validation_images_input_pre,
                    masks=validation_masks_input_pre,
                    num_inference_steps=self.num_inference_steps,
                    generator=generator,
                    guidance_scale=guidance_scale_final,
                    latents=latents_pre,
                ).latents
            torch.cuda.empty_cache()

            def decode_latents(latents, weight_dtype):
                latents = 1 / self.vae.config.scaling_factor * latents
                video = []
                for t in range(latents.shape[0]):
                    video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
                video = torch.concat(video, dim=0)
                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                video = video.float()
                return video
            with torch.no_grad():
                video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
                images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
            torch.cuda.empty_cache()

            ## replace input frames with updated frames
            black_image = Image.new('L', validation_masks_input[0].size, color=0)
            for i,index in enumerate(sample_index):
                latents[index] = latents_pre_out[i]
                validation_masks_input[index] = black_image
                validation_images_input[index] = images_pre_out[i]
                resized_frames[index] = images_pre_out[i]
        else:
            latents_pre_out=None
            sample_index=None
        gc.collect()
        torch.cuda.empty_cache()

        ################ Frame-by-frame inference ################
        ## add priori
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        latents = noisy_latents
        with torch.no_grad():
            images = self.pipeline(
                num_frames=nframes,
                prompt=validation_prompt,
                images=validation_images_input,
                masks=validation_masks_input,
                num_inference_steps=self.num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale_final,
                latents=latents,
            ).frames
        images = images[:real_video_length]

        gc.collect()
        torch.cuda.empty_cache()

        ################ Compose ################
        binary_masks = validation_masks_input_ori
        mask_blurreds = []
        if blended:
            # blur, you can adjust the parameters for better performance
            for i in range(len(binary_masks)):
                mask_blurred = cv2.GaussianBlur(np.array(binary_masks[i]), (21, 21), 0)/255.
                binary_mask = 1-(1-np.array(binary_masks[i])/255.) * (1-mask_blurred)
                mask_blurreds.append(Image.fromarray((binary_mask*255).astype(np.uint8)))
            binary_masks = mask_blurreds

        comp_frames = []
        for i in range(len(images)):
            mask = np.expand_dims(np.array(binary_masks[i]),2).repeat(3, axis=2).astype(np.float32)/255.
            img = (np.array(images[i]).astype(np.uint8) * mask \
                + np.array(resized_frames_ori[i]).astype(np.uint8) * (1 - mask)).astype(np.uint8)
            comp_frames.append(Image.fromarray(img))

        default_fps = fps
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                default_fps, comp_frames[0].size)
        for f in range(real_video_length):
            img = np.array(comp_frames[f]).astype(np.uint8)
            writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        writer.release()
        ################################

        return output_path
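Two details of `DiffuEraser.forward` above are easy to miss: the pre-inference pass runs on `nframes` evenly spaced keyframes, and a single noise tensor is tiled across all clips so every clip is denoised from identical per-frame noise. A small self-contained sketch of both ideas, using dummy shapes rather than the repository's API:

```python
import torch
from einops import repeat

n_total_frames, nframes, n_clip = 110, 22, 5   # hypothetical 110-frame input

# Evenly spaced keyframes for the pre-inference pass (mirrors `sample_index` above).
step = n_total_frames / nframes
sample_index = [int(i * step) for i in range(nframes)][:22]
print(sample_index[:5], "...", sample_index[-1])    # [0, 5, 10, 15, 20] ... 105

# One noise tensor per clip, repeated over the clips and trimmed to the video length.
noise_pre = torch.randn(nframes, 4, 64, 80)
noise = repeat(noise_pre, "t c h w -> (repeat t) c h w", repeat=n_clip)[:n_total_frames]
print(noise.shape)                                   # torch.Size([110, 4, 64, 80])
assert torch.equal(noise[0], noise[nframes])         # clips share the same per-frame noise
```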
diffueraser/pipeline_diffueraser.py
ADDED
@@ -0,0 +1,1349 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
from einops import rearrange, repeat
from dataclasses import dataclass
import copy
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
    BaseOutput
)
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UniPCMultistepScheduler,
)

from libs.unet_2d_condition import UNet2DConditionModel
from libs.brushnet_CA import BrushNetModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

def get_frames_context_swap(total_frames=192, overlap=4, num_frames_per_clip=24):
    if total_frames<num_frames_per_clip:
        num_frames_per_clip = total_frames
    context_list = []
    context_list_swap = []
    for i in range(1, 2): # i=1
        sample_interval = np.array(range(0,total_frames,i))
        n = len(sample_interval)
        if n>num_frames_per_clip:
            ## [0,num_frames_per_clip-1], [num_frames_per_clip, 2*num_frames_per_clip-1]....
            for k in range(0,n-num_frames_per_clip,num_frames_per_clip-overlap):
                context_list.append(sample_interval[k:k+num_frames_per_clip])
            if k+num_frames_per_clip < n and i==1:
                context_list.append(sample_interval[n-num_frames_per_clip:n])
            context_list_swap.append(sample_interval[0:num_frames_per_clip])
            for k in range(num_frames_per_clip//2, n-num_frames_per_clip, num_frames_per_clip-overlap):
                context_list_swap.append(sample_interval[k:k+num_frames_per_clip])
            if k+num_frames_per_clip < n and i==1:
                context_list_swap.append(sample_interval[n-num_frames_per_clip:n])
        if n==num_frames_per_clip:
            context_list.append(sample_interval[n-num_frames_per_clip:n])
            context_list_swap.append(sample_interval[n-num_frames_per_clip:n])
    return context_list, context_list_swap

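# Illustrative note (not part of the original file): get_frames_context_swap splits a long
# sequence into overlapping clips plus a tail clip covering the last num_frames_per_clip
# frames. For example, with total_frames=60, overlap=4, num_frames_per_clip=24, the first
# list starts clips at frames 0 and 20 and adds a tail clip over frames 36-59, while the
# "swap" list also contains clips shifted by half a clip (e.g. starting at frame 12), so
# the clip boundaries of the two passes fall at different frames.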
110 |
+
@dataclass
|
111 |
+
class DiffuEraserPipelineOutput(BaseOutput):
|
112 |
+
frames: Union[torch.Tensor, np.ndarray]
|
113 |
+
latents: Union[torch.Tensor, np.ndarray]
|
114 |
+
|
115 |
+
class StableDiffusionDiffuEraserPipeline(
|
116 |
+
DiffusionPipeline,
|
117 |
+
StableDiffusionMixin,
|
118 |
+
TextualInversionLoaderMixin,
|
119 |
+
LoraLoaderMixin,
|
120 |
+
IPAdapterMixin,
|
121 |
+
FromSingleFileMixin,
|
122 |
+
):
|
123 |
+
r"""
|
124 |
+
Pipeline for video inpainting using Video Diffusion Model with BrushNet guidance.
|
125 |
+
|
126 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
127 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
128 |
+
|
129 |
+
The pipeline also inherits the following loading methods:
|
130 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
131 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
132 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
133 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
134 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
135 |
+
|
136 |
+
Args:
|
137 |
+
vae ([`AutoencoderKL`]):
|
138 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
139 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
140 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
141 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
142 |
+
A `CLIPTokenizer` to tokenize text.
|
143 |
+
unet ([`UNet2DConditionModel`]):
|
144 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
145 |
+
brushnet ([`BrushNetModel`]):
|
146 |
+
Provides additional conditioning to the `unet` during the denoising process.
|
147 |
+
scheduler ([`SchedulerMixin`]):
|
148 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
149 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
150 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
151 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
152 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
153 |
+
about a model's potential harms.
|
154 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
155 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
156 |
+
"""
|
157 |
+
|
158 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
159 |
+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
160 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
161 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
162 |
+
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
vae: AutoencoderKL,
|
166 |
+
text_encoder: CLIPTextModel,
|
167 |
+
tokenizer: CLIPTokenizer,
|
168 |
+
unet: UNet2DConditionModel,
|
169 |
+
brushnet: BrushNetModel,
|
170 |
+
scheduler: KarrasDiffusionSchedulers,
|
171 |
+
safety_checker: StableDiffusionSafetyChecker,
|
172 |
+
feature_extractor: CLIPImageProcessor,
|
173 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
174 |
+
requires_safety_checker: bool = True,
|
175 |
+
):
|
176 |
+
super().__init__()
|
177 |
+
|
178 |
+
if safety_checker is None and requires_safety_checker:
|
179 |
+
logger.warning(
|
180 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
181 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
182 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
183 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
184 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
185 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
186 |
+
)
|
187 |
+
|
188 |
+
if safety_checker is not None and feature_extractor is None:
|
189 |
+
raise ValueError(
|
190 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
191 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
192 |
+
)
|
193 |
+
|
194 |
+
self.register_modules(
|
195 |
+
vae=vae,
|
196 |
+
text_encoder=text_encoder,
|
197 |
+
tokenizer=tokenizer,
|
198 |
+
unet=unet,
|
199 |
+
brushnet=brushnet,
|
200 |
+
scheduler=scheduler,
|
201 |
+
safety_checker=safety_checker,
|
202 |
+
feature_extractor=feature_extractor,
|
203 |
+
image_encoder=image_encoder,
|
204 |
+
)
|
205 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
206 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
207 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
208 |
+
|
209 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
210 |
+
def _encode_prompt(
|
211 |
+
self,
|
212 |
+
prompt,
|
213 |
+
device,
|
214 |
+
num_images_per_prompt,
|
215 |
+
do_classifier_free_guidance,
|
216 |
+
negative_prompt=None,
|
217 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
218 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
219 |
+
lora_scale: Optional[float] = None,
|
220 |
+
**kwargs,
|
221 |
+
):
|
222 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
223 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
224 |
+
|
225 |
+
prompt_embeds_tuple = self.encode_prompt(
|
226 |
+
prompt=prompt,
|
227 |
+
device=device,
|
228 |
+
num_images_per_prompt=num_images_per_prompt,
|
229 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
230 |
+
negative_prompt=negative_prompt,
|
231 |
+
prompt_embeds=prompt_embeds,
|
232 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
233 |
+
lora_scale=lora_scale,
|
234 |
+
**kwargs,
|
235 |
+
)
|
236 |
+
|
237 |
+
# concatenate for backwards compatibility
|
238 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
239 |
+
|
240 |
+
return prompt_embeds
|
241 |
+
|
242 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
243 |
+
def encode_prompt(
|
244 |
+
self,
|
245 |
+
prompt,
|
246 |
+
device,
|
247 |
+
num_images_per_prompt,
|
248 |
+
do_classifier_free_guidance,
|
249 |
+
negative_prompt=None,
|
250 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
251 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
252 |
+
lora_scale: Optional[float] = None,
|
253 |
+
clip_skip: Optional[int] = None,
|
254 |
+
):
|
255 |
+
r"""
|
256 |
+
Encodes the prompt into text encoder hidden states.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
prompt (`str` or `List[str]`, *optional*):
|
260 |
+
prompt to be encoded
|
261 |
+
device: (`torch.device`):
|
262 |
+
torch device
|
263 |
+
num_images_per_prompt (`int`):
|
264 |
+
number of images that should be generated per prompt
|
265 |
+
do_classifier_free_guidance (`bool`):
|
266 |
+
whether to use classifier free guidance or not
|
267 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
268 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
269 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
270 |
+
less than `1`).
|
271 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
272 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
273 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
274 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
275 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
276 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
277 |
+
argument.
|
278 |
+
lora_scale (`float`, *optional*):
|
279 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
280 |
+
clip_skip (`int`, *optional*):
|
281 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
282 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
283 |
+
"""
|
284 |
+
# set lora scale so that monkey patched LoRA
|
285 |
+
# function of text encoder can correctly access it
|
286 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
287 |
+
self._lora_scale = lora_scale
|
288 |
+
|
289 |
+
# dynamically adjust the LoRA scale
|
290 |
+
if not USE_PEFT_BACKEND:
|
291 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
292 |
+
else:
|
293 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
294 |
+
|
295 |
+
if prompt is not None and isinstance(prompt, str):
|
296 |
+
batch_size = 1
|
297 |
+
elif prompt is not None and isinstance(prompt, list):
|
298 |
+
batch_size = len(prompt)
|
299 |
+
else:
|
300 |
+
batch_size = prompt_embeds.shape[0]
|
301 |
+
|
302 |
+
if prompt_embeds is None:
|
303 |
+
# textual inversion: process multi-vector tokens if necessary
|
304 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
305 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
306 |
+
|
307 |
+
text_inputs = self.tokenizer(
|
308 |
+
prompt,
|
309 |
+
padding="max_length",
|
310 |
+
max_length=self.tokenizer.model_max_length,
|
311 |
+
truncation=True,
|
312 |
+
return_tensors="pt",
|
313 |
+
)
|
314 |
+
text_input_ids = text_inputs.input_ids
|
315 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
316 |
+
|
317 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
318 |
+
text_input_ids, untruncated_ids
|
319 |
+
):
|
320 |
+
removed_text = self.tokenizer.batch_decode(
|
321 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
322 |
+
)
|
323 |
+
logger.warning(
|
324 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
325 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
326 |
+
)
|
327 |
+
|
328 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
329 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
330 |
+
else:
|
331 |
+
attention_mask = None
|
332 |
+
|
333 |
+
if clip_skip is None:
|
334 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
335 |
+
prompt_embeds = prompt_embeds[0]
|
336 |
+
else:
|
337 |
+
prompt_embeds = self.text_encoder(
|
338 |
+
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
339 |
+
)
|
340 |
+
# Access the `hidden_states` first, that contains a tuple of
|
341 |
+
# all the hidden states from the encoder layers. Then index into
|
342 |
+
# the tuple to access the hidden states from the desired layer.
|
343 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
344 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
345 |
+
# representations. The `last_hidden_states` that we typically use for
|
346 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
347 |
+
# layer.
|
348 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
349 |
+
|
350 |
+
if self.text_encoder is not None:
|
351 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
352 |
+
elif self.unet is not None:
|
353 |
+
prompt_embeds_dtype = self.unet.dtype
|
354 |
+
else:
|
355 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
356 |
+
|
357 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
358 |
+
|
359 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
360 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
361 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
362 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
363 |
+
|
364 |
+
# get unconditional embeddings for classifier free guidance
|
365 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
366 |
+
uncond_tokens: List[str]
|
367 |
+
if negative_prompt is None:
|
368 |
+
uncond_tokens = [""] * batch_size
|
369 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
370 |
+
raise TypeError(
|
371 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
372 |
+
f" {type(prompt)}."
|
373 |
+
)
|
374 |
+
elif isinstance(negative_prompt, str):
|
375 |
+
uncond_tokens = [negative_prompt]
|
376 |
+
elif batch_size != len(negative_prompt):
|
377 |
+
raise ValueError(
|
378 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
379 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
380 |
+
" the batch size of `prompt`."
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
uncond_tokens = negative_prompt
|
384 |
+
|
385 |
+
# textual inversion: process multi-vector tokens if necessary
|
386 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
387 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
388 |
+
|
389 |
+
max_length = prompt_embeds.shape[1]
|
390 |
+
uncond_input = self.tokenizer(
|
391 |
+
uncond_tokens,
|
392 |
+
padding="max_length",
|
393 |
+
max_length=max_length,
|
394 |
+
truncation=True,
|
395 |
+
return_tensors="pt",
|
396 |
+
)
|
397 |
+
|
398 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
399 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
400 |
+
else:
|
401 |
+
attention_mask = None
|
402 |
+
|
403 |
+
negative_prompt_embeds = self.text_encoder(
|
404 |
+
uncond_input.input_ids.to(device),
|
405 |
+
attention_mask=attention_mask,
|
406 |
+
)
|
407 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
408 |
+
|
409 |
+
if do_classifier_free_guidance:
|
410 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
411 |
+
seq_len = negative_prompt_embeds.shape[1]
|
412 |
+
|
413 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
414 |
+
|
415 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
416 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
417 |
+
|
418 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
419 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
420 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
421 |
+
|
422 |
+
return prompt_embeds, negative_prompt_embeds
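# Note (added for clarity, not part of the original code): `encode_prompt` returns a
# (prompt_embeds, negative_prompt_embeds) tuple of CLIP hidden states with shape
# (batch_size * num_images_per_prompt, 77, hidden_dim). For classifier-free guidance
# the caller concatenates the two along the batch dimension, negative first, e.g.
#     prompt_embeds, negative_prompt_embeds = self.encode_prompt(
#         "clean background", device, 1, do_classifier_free_guidance=True
#     )
#     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# which is exactly what `__call__` does further below.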
|
423 |
+
|
424 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
425 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
426 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
427 |
+
|
428 |
+
if not isinstance(image, torch.Tensor):
|
429 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
430 |
+
|
431 |
+
image = image.to(device=device, dtype=dtype)
|
432 |
+
if output_hidden_states:
|
433 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
434 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
435 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
436 |
+
torch.zeros_like(image), output_hidden_states=True
|
437 |
+
).hidden_states[-2]
|
438 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
439 |
+
num_images_per_prompt, dim=0
|
440 |
+
)
|
441 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
442 |
+
else:
|
443 |
+
image_embeds = self.image_encoder(image).image_embeds
|
444 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
445 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
446 |
+
|
447 |
+
return image_embeds, uncond_image_embeds
|
448 |
+
|
449 |
+
def decode_latents(self, latents, weight_dtype):
|
450 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
451 |
+
video = []
|
452 |
+
for t in range(latents.shape[0]):
|
453 |
+
video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
|
454 |
+
video = torch.concat(video, dim=0)
|
455 |
+
|
456 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
457 |
+
video = video.float()
|
458 |
+
return video
|
459 |
+
|
460 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
461 |
+
def prepare_ip_adapter_image_embeds(
|
462 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
463 |
+
):
|
464 |
+
if ip_adapter_image_embeds is None:
|
465 |
+
if not isinstance(ip_adapter_image, list):
|
466 |
+
ip_adapter_image = [ip_adapter_image]
|
467 |
+
|
468 |
+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
469 |
+
raise ValueError(
|
470 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
471 |
+
)
|
472 |
+
|
473 |
+
image_embeds = []
|
474 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
475 |
+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
476 |
+
):
|
477 |
+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
478 |
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
479 |
+
single_ip_adapter_image, device, 1, output_hidden_state
|
480 |
+
)
|
481 |
+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
482 |
+
single_negative_image_embeds = torch.stack(
|
483 |
+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
484 |
+
)
|
485 |
+
|
486 |
+
if do_classifier_free_guidance:
|
487 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
488 |
+
single_image_embeds = single_image_embeds.to(device)
|
489 |
+
|
490 |
+
image_embeds.append(single_image_embeds)
|
491 |
+
else:
|
492 |
+
repeat_dims = [1]
|
493 |
+
image_embeds = []
|
494 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
495 |
+
if do_classifier_free_guidance:
|
496 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
497 |
+
single_image_embeds = single_image_embeds.repeat(
|
498 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
499 |
+
)
|
500 |
+
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
501 |
+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
502 |
+
)
|
503 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
504 |
+
else:
|
505 |
+
single_image_embeds = single_image_embeds.repeat(
|
506 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
507 |
+
)
|
508 |
+
image_embeds.append(single_image_embeds)
|
509 |
+
|
510 |
+
return image_embeds
|
511 |
+
|
512 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
513 |
+
def run_safety_checker(self, image, device, dtype):
|
514 |
+
if self.safety_checker is None:
|
515 |
+
has_nsfw_concept = None
|
516 |
+
else:
|
517 |
+
if torch.is_tensor(image):
|
518 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
519 |
+
else:
|
520 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
521 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
522 |
+
image, has_nsfw_concept = self.safety_checker(
|
523 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
524 |
+
)
|
525 |
+
return image, has_nsfw_concept
|
526 |
+
|
527 |
+
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
528 |
+
def decode_latents(self, latents, weight_dtype):
|
529 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
530 |
+
video = []
|
531 |
+
for t in range(latents.shape[0]):
|
532 |
+
video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
|
533 |
+
video = torch.concat(video, dim=0)
|
534 |
+
|
535 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
536 |
+
video = video.float()
|
537 |
+
return video
|
538 |
+
|
539 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
540 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
541 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
542 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
543 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
544 |
+
# and should be between [0, 1]
|
545 |
+
|
546 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
547 |
+
extra_step_kwargs = {}
|
548 |
+
if accepts_eta:
|
549 |
+
extra_step_kwargs["eta"] = eta
|
550 |
+
|
551 |
+
# check if the scheduler accepts generator
|
552 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
553 |
+
if accepts_generator:
|
554 |
+
extra_step_kwargs["generator"] = generator
|
555 |
+
return extra_step_kwargs
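# Note (added for clarity, not part of the original code): `prepare_extra_step_kwargs`
# inspects `scheduler.step` and forwards only the arguments it supports. With a DDIM
# scheduler, for example,
#     self.prepare_extra_step_kwargs(generator, eta=0.0)
# returns {"eta": 0.0, "generator": generator}, while schedulers whose `step` has no
# `eta` parameter receive only the generator (if they accept one).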
|
556 |
+
|
557 |
+
def check_inputs(
|
558 |
+
self,
|
559 |
+
prompt,
|
560 |
+
images,
|
561 |
+
masks,
|
562 |
+
callback_steps,
|
563 |
+
negative_prompt=None,
|
564 |
+
prompt_embeds=None,
|
565 |
+
negative_prompt_embeds=None,
|
566 |
+
ip_adapter_image=None,
|
567 |
+
ip_adapter_image_embeds=None,
|
568 |
+
brushnet_conditioning_scale=1.0,
|
569 |
+
control_guidance_start=0.0,
|
570 |
+
control_guidance_end=1.0,
|
571 |
+
callback_on_step_end_tensor_inputs=None,
|
572 |
+
):
|
573 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
574 |
+
raise ValueError(
|
575 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
576 |
+
f" {type(callback_steps)}."
|
577 |
+
)
|
578 |
+
|
579 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
580 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
581 |
+
):
|
582 |
+
raise ValueError(
|
583 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
584 |
+
)
|
585 |
+
|
586 |
+
if prompt is not None and prompt_embeds is not None:
|
587 |
+
raise ValueError(
|
588 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
589 |
+
" only forward one of the two."
|
590 |
+
)
|
591 |
+
elif prompt is None and prompt_embeds is None:
|
592 |
+
raise ValueError(
|
593 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
594 |
+
)
|
595 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
596 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
597 |
+
|
598 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
599 |
+
raise ValueError(
|
600 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
601 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
602 |
+
)
|
603 |
+
|
604 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
605 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
606 |
+
raise ValueError(
|
607 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
608 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
609 |
+
f" {negative_prompt_embeds.shape}."
|
610 |
+
)
|
611 |
+
|
612 |
+
# Check `image`
|
613 |
+
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
614 |
+
self.brushnet, torch._dynamo.eval_frame.OptimizedModule
|
615 |
+
)
|
616 |
+
if (
|
617 |
+
isinstance(self.brushnet, BrushNetModel)
|
618 |
+
or is_compiled
|
619 |
+
and isinstance(self.brushnet._orig_mod, BrushNetModel)
|
620 |
+
):
|
621 |
+
self.check_image(images, masks, prompt, prompt_embeds)
|
622 |
+
else:
|
623 |
+
assert False
|
624 |
+
|
625 |
+
# Check `brushnet_conditioning_scale`
|
626 |
+
if (
|
627 |
+
isinstance(self.brushnet, BrushNetModel)
|
628 |
+
or is_compiled
|
629 |
+
and isinstance(self.brushnet._orig_mod, BrushNetModel)
|
630 |
+
):
|
631 |
+
if not isinstance(brushnet_conditioning_scale, float):
|
632 |
+
raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.")
|
633 |
+
else:
|
634 |
+
assert False
|
635 |
+
|
636 |
+
if not isinstance(control_guidance_start, (tuple, list)):
|
637 |
+
control_guidance_start = [control_guidance_start]
|
638 |
+
|
639 |
+
if not isinstance(control_guidance_end, (tuple, list)):
|
640 |
+
control_guidance_end = [control_guidance_end]
|
641 |
+
|
642 |
+
if len(control_guidance_start) != len(control_guidance_end):
|
643 |
+
raise ValueError(
|
644 |
+
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
|
645 |
+
)
|
646 |
+
|
647 |
+
for start, end in zip(control_guidance_start, control_guidance_end):
|
648 |
+
if start >= end:
|
649 |
+
raise ValueError(
|
650 |
+
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
|
651 |
+
)
|
652 |
+
if start < 0.0:
|
653 |
+
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
|
654 |
+
if end > 1.0:
|
655 |
+
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
656 |
+
|
657 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
658 |
+
raise ValueError(
|
659 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
660 |
+
)
|
661 |
+
|
662 |
+
if ip_adapter_image_embeds is not None:
|
663 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
664 |
+
raise ValueError(
|
665 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
666 |
+
)
|
667 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
668 |
+
raise ValueError(
|
669 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
670 |
+
)
|
671 |
+
|
672 |
+
def check_image(self, images, masks, prompt, prompt_embeds):
|
673 |
+
for image in images:
|
674 |
+
image_is_pil = isinstance(image, PIL.Image.Image)
|
675 |
+
image_is_tensor = isinstance(image, torch.Tensor)
|
676 |
+
image_is_np = isinstance(image, np.ndarray)
|
677 |
+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
678 |
+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
679 |
+
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
680 |
+
|
681 |
+
if (
|
682 |
+
not image_is_pil
|
683 |
+
and not image_is_tensor
|
684 |
+
and not image_is_np
|
685 |
+
and not image_is_pil_list
|
686 |
+
and not image_is_tensor_list
|
687 |
+
and not image_is_np_list
|
688 |
+
):
|
689 |
+
raise TypeError(
|
690 |
+
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
691 |
+
)
|
692 |
+
for mask in masks:
|
693 |
+
mask_is_pil = isinstance(mask, PIL.Image.Image)
|
694 |
+
mask_is_tensor = isinstance(mask, torch.Tensor)
|
695 |
+
mask_is_np = isinstance(mask, np.ndarray)
|
696 |
+
mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image)
|
697 |
+
mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor)
|
698 |
+
mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)
|
699 |
+
|
700 |
+
if (
|
701 |
+
not mask_is_pil
|
702 |
+
and not mask_is_tensor
|
703 |
+
and not mask_is_np
|
704 |
+
and not mask_is_pil_list
|
705 |
+
and not mask_is_tensor_list
|
706 |
+
and not mask_is_np_list
|
707 |
+
):
|
708 |
+
raise TypeError(
|
709 |
+
f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}"
|
710 |
+
)
|
711 |
+
|
712 |
+
if image_is_pil:
|
713 |
+
image_batch_size = 1
|
714 |
+
else:
|
715 |
+
image_batch_size = len(image)
|
716 |
+
|
717 |
+
if prompt is not None and isinstance(prompt, str):
|
718 |
+
prompt_batch_size = 1
|
719 |
+
elif prompt is not None and isinstance(prompt, list):
|
720 |
+
prompt_batch_size = len(prompt)
|
721 |
+
elif prompt_embeds is not None:
|
722 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
723 |
+
|
724 |
+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
725 |
+
raise ValueError(
|
726 |
+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
727 |
+
)
|
728 |
+
|
729 |
+
def prepare_image(
|
730 |
+
self,
|
731 |
+
images,
|
732 |
+
width,
|
733 |
+
height,
|
734 |
+
batch_size,
|
735 |
+
num_images_per_prompt,
|
736 |
+
device,
|
737 |
+
dtype,
|
738 |
+
do_classifier_free_guidance=False,
|
739 |
+
guess_mode=False,
|
740 |
+
):
|
741 |
+
images_new = []
|
742 |
+
for image in images:
|
743 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
744 |
+
image_batch_size = image.shape[0]
|
745 |
+
|
746 |
+
if image_batch_size == 1:
|
747 |
+
repeat_by = batch_size
|
748 |
+
else:
|
749 |
+
# image batch size is the same as prompt batch size
|
750 |
+
repeat_by = num_images_per_prompt
|
751 |
+
|
752 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
753 |
+
|
754 |
+
image = image.to(device=device, dtype=dtype)
|
755 |
+
|
756 |
+
# if do_classifier_free_guidance and not guess_mode:
|
757 |
+
# image = torch.cat([image] * 2)
|
758 |
+
images_new.append(image)
|
759 |
+
|
760 |
+
return images_new
|
761 |
+
|
762 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
763 |
+
def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
|
764 |
+
# shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
765 |
+
#b,c,n,h,w
|
766 |
+
shape = (
|
767 |
+
batch_size,
|
768 |
+
num_channels_latents,
|
769 |
+
num_frames,
|
770 |
+
height // self.vae_scale_factor,
|
771 |
+
width // self.vae_scale_factor
|
772 |
+
)
|
773 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
774 |
+
raise ValueError(
|
775 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
776 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
777 |
+
)
|
778 |
+
|
779 |
+
if latents is None:
|
780 |
+
# noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
781 |
+
noise = rearrange(randn_tensor(shape, generator=generator, device=device, dtype=dtype), "b c t h w -> (b t) c h w")
|
782 |
+
else:
|
783 |
+
noise = latents.to(device)
|
784 |
+
|
785 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
786 |
+
latents = noise * self.scheduler.init_noise_sigma
|
787 |
+
return latents, noise
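# Note (added for clarity, not part of the original code): latents are sampled with
# shape (batch, channels, num_frames, H / vae_scale_factor, W / vae_scale_factor) and
# rearranged to (batch * num_frames, channels, h, w), so every frame is denoised by the
# 2D UNet while the temporal layers regroup frames via `num_frames`. For example,
# batch_size=1, 4 latent channels, 24 frames and a 512x512 video with the standard SD
# VAE (scale factor 8) give a (24, 4, 64, 64) tensor, scaled by
# `self.scheduler.init_noise_sigma`.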
|
788 |
+
|
789 |
+
@staticmethod
|
790 |
+
def temp_blend(a, b, overlap):
|
791 |
+
factor = torch.arange(overlap).to(b.device).view(overlap, 1, 1, 1) / (overlap - 1)
|
792 |
+
a[:overlap, ...] = (1 - factor) * a[:overlap, ...] + factor * b[:overlap, ...]
|
793 |
+
a[overlap:, ...] = b[overlap:, ...]
|
794 |
+
return a
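# Worked example (added for clarity, not part of the original code): `temp_blend`
# linearly cross-fades the first `overlap` frames of `a` towards `b` (weights 0 -> 1
# across the overlap) and copies the remaining frames of `b` verbatim:
#     a = torch.zeros(4, 1, 1, 1); b = torch.ones(4, 1, 1, 1)
#     self.temp_blend(a, b, overlap=3)
#     # blend weights are 0, 0.5, 1 -> result is [0.0, 0.5, 1.0, 1.0] along frames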
|
795 |
+
|
796 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
797 |
+
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
798 |
+
"""
|
799 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
800 |
+
|
801 |
+
Args:
|
802 |
+
w (`torch.Tensor`):
|
803 |
+
guidance scale values at which to generate the embedding vectors
|
804 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
805 |
+
dimension of the embeddings to generate
|
806 |
+
dtype:
|
807 |
+
data type of the generated embeddings
|
808 |
+
|
809 |
+
Returns:
|
810 |
+
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
811 |
+
"""
|
812 |
+
assert len(w.shape) == 1
|
813 |
+
w = w * 1000.0
|
814 |
+
|
815 |
+
half_dim = embedding_dim // 2
|
816 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
817 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
818 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
819 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
820 |
+
if embedding_dim % 2 == 1: # zero pad
|
821 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
822 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
823 |
+
return emb
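# Note (added for clarity, not part of the original code): this is a sinusoidal
# embedding of the rescaled guidance weight, used only when the UNet was trained with
# `time_cond_proj_dim` (guidance-distilled models). Sketch, with an assumed 256-dim
# projection:
#     w = torch.tensor([7.5 - 1.0])                      # guidance_scale - 1
#     emb = self.get_guidance_scale_embedding(w, embedding_dim=256)
#     emb.shape                                          # torch.Size([1, 256])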
|
824 |
+
|
825 |
+
@property
|
826 |
+
def guidance_scale(self):
|
827 |
+
return self._guidance_scale
|
828 |
+
|
829 |
+
@property
|
830 |
+
def clip_skip(self):
|
831 |
+
return self._clip_skip
|
832 |
+
|
833 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
834 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
835 |
+
# corresponds to doing no classifier free guidance.
|
836 |
+
@property
|
837 |
+
def do_classifier_free_guidance(self):
|
838 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
839 |
+
|
840 |
+
@property
|
841 |
+
def cross_attention_kwargs(self):
|
842 |
+
return self._cross_attention_kwargs
|
843 |
+
|
844 |
+
@property
|
845 |
+
def num_timesteps(self):
|
846 |
+
return self._num_timesteps
|
847 |
+
|
848 |
+
# based on BrushNet: https://github.com/TencentARC/BrushNet/blob/main/src/diffusers/pipelines/brushnet/pipeline_brushnet.py
|
849 |
+
@torch.no_grad()
|
850 |
+
def __call__(
|
851 |
+
self,
|
852 |
+
num_frames: Optional[int] = 24,
|
853 |
+
prompt: Union[str, List[str]] = None,
|
854 |
+
images: PipelineImageInput = None, ##masked images
|
855 |
+
masks: PipelineImageInput = None,
|
856 |
+
height: Optional[int] = None,
|
857 |
+
width: Optional[int] = None,
|
858 |
+
num_inference_steps: int = 50,
|
859 |
+
timesteps: List[int] = None,
|
860 |
+
guidance_scale: float = 7.5,
|
861 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
862 |
+
num_images_per_prompt: Optional[int] = 1,
|
863 |
+
eta: float = 0.0,
|
864 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
865 |
+
latents: Optional[torch.FloatTensor] = None,
|
866 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
867 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
868 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
869 |
+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
870 |
+
output_type: Optional[str] = "pil",
|
871 |
+
return_dict: bool = True,
|
872 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
873 |
+
brushnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
874 |
+
guess_mode: bool = False,
|
875 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
876 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
877 |
+
clip_skip: Optional[int] = None,
|
878 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
879 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
880 |
+
**kwargs,
|
881 |
+
):
|
882 |
+
r"""
|
883 |
+
The call function to the pipeline for generation.
|
884 |
+
|
885 |
+
Args:
|
886 |
+
prompt (`str` or `List[str]`, *optional*):
|
887 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
888 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
889 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
890 |
+
The BrushNet branch input condition to provide guidance to the `unet` for generation.
|
891 |
+
mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
892 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
893 |
+
The mask marking the regions to inpaint, used together with the masked frames as the BrushNet condition.
|
894 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
895 |
+
The height in pixels of the generated image.
|
896 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
897 |
+
The width in pixels of the generated image.
|
898 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
899 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
900 |
+
expense of slower inference.
|
901 |
+
timesteps (`List[int]`, *optional*):
|
902 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
903 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
904 |
+
passed will be used. Must be in descending order.
|
905 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
906 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
907 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
908 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
909 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
910 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
911 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
912 |
+
The number of images to generate per prompt.
|
913 |
+
eta (`float`, *optional*, defaults to 0.0):
|
914 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
915 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
916 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
917 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
918 |
+
generation deterministic.
|
919 |
+
latents (`torch.FloatTensor`, *optional*):
|
920 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
921 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
922 |
+
tensor is generated by sampling using the supplied random `generator`.
|
923 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
924 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
925 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
926 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
927 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
928 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
929 |
+
ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
930 |
+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
931 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of IP-Adapters.
|
932 |
+
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
|
933 |
+
if `do_classifier_free_guidance` is set to `True`.
|
934 |
+
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
|
935 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
936 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
937 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
938 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
939 |
+
plain tuple.
|
940 |
+
callback (`Callable`, *optional*):
|
941 |
+
A function that is called every `callback_steps` steps during inference. The function is called with the
|
942 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
943 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
944 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
945 |
+
every step.
|
946 |
+
cross_attention_kwargs (`dict`, *optional*):
|
947 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
948 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
949 |
+
brushnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
950 |
+
The outputs of the BrushNet are multiplied by `brushnet_conditioning_scale` before they are added
|
951 |
+
to the residual in the original `unet`. If multiple BrushNets are specified in `init`, you can set
|
952 |
+
the corresponding scale as a list.
|
953 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
954 |
+
The BrushNet encoder tries to recognize the content of the input image even if you remove all
|
955 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
956 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
957 |
+
The percentage of total steps at which the BrushNet starts applying.
|
958 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
959 |
+
The percentage of total steps at which the BrushNet stops applying.
|
960 |
+
clip_skip (`int`, *optional*):
|
961 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
962 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
963 |
+
callback_on_step_end (`Callable`, *optional*):
|
964 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
965 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
966 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
967 |
+
`callback_on_step_end_tensor_inputs`.
|
968 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
969 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
970 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
971 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
972 |
+
|
973 |
+
Examples:
|
974 |
+
|
975 |
+
Returns:
|
976 |
+
[`DiffuEraserPipelineOutput`] or `tuple`:
|
977 |
+
If `return_dict` is `True`, a [`DiffuEraserPipelineOutput`] is returned,
|
978 |
+
otherwise a `tuple` is returned where the first element is the generated video frames and the
|
979 |
+
second element is the corresponding latent tensor, matching the fields of
|
980 |
+
"not-safe-for-work" (nsfw) content.
|
981 |
+
"""
|
982 |
+
|
983 |
+
callback = kwargs.pop("callback", None)
|
984 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
985 |
+
|
986 |
+
if callback is not None:
|
987 |
+
deprecate(
|
988 |
+
"callback",
|
989 |
+
"1.0.0",
|
990 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
991 |
+
)
|
992 |
+
if callback_steps is not None:
|
993 |
+
deprecate(
|
994 |
+
"callback_steps",
|
995 |
+
"1.0.0",
|
996 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
997 |
+
)
|
998 |
+
|
999 |
+
brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet
|
1000 |
+
|
1001 |
+
# align format for control guidance
|
1002 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
1003 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
1004 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
1005 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
1006 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
1007 |
+
control_guidance_start, control_guidance_end = (
|
1008 |
+
[control_guidance_start],
|
1009 |
+
[control_guidance_end],
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
# 1. Check inputs. Raise error if not correct
|
1013 |
+
self.check_inputs(
|
1014 |
+
prompt,
|
1015 |
+
images,
|
1016 |
+
masks,
|
1017 |
+
callback_steps,
|
1018 |
+
negative_prompt,
|
1019 |
+
prompt_embeds,
|
1020 |
+
negative_prompt_embeds,
|
1021 |
+
ip_adapter_image,
|
1022 |
+
ip_adapter_image_embeds,
|
1023 |
+
brushnet_conditioning_scale,
|
1024 |
+
control_guidance_start,
|
1025 |
+
control_guidance_end,
|
1026 |
+
callback_on_step_end_tensor_inputs,
|
1027 |
+
)
|
1028 |
+
|
1029 |
+
self._guidance_scale = guidance_scale
|
1030 |
+
self._clip_skip = clip_skip
|
1031 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
1032 |
+
|
1033 |
+
# 2. Define call parameters
|
1034 |
+
if prompt is not None and isinstance(prompt, str):
|
1035 |
+
batch_size = 1
|
1036 |
+
elif prompt is not None and isinstance(prompt, list):
|
1037 |
+
batch_size = len(prompt)
|
1038 |
+
else:
|
1039 |
+
batch_size = prompt_embeds.shape[0]
|
1040 |
+
|
1041 |
+
device = self._execution_device
|
1042 |
+
|
1043 |
+
global_pool_conditions = (
|
1044 |
+
brushnet.config.global_pool_conditions
|
1045 |
+
if isinstance(brushnet, BrushNetModel)
|
1046 |
+
else brushnet.nets[0].config.global_pool_conditions
|
1047 |
+
)
|
1048 |
+
guess_mode = guess_mode or global_pool_conditions
|
1049 |
+
video_length = len(images)
|
1050 |
+
|
1051 |
+
# 3. Encode input prompt
|
1052 |
+
text_encoder_lora_scale = (
|
1053 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
1054 |
+
)
|
1055 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
1056 |
+
prompt,
|
1057 |
+
device,
|
1058 |
+
num_images_per_prompt,
|
1059 |
+
self.do_classifier_free_guidance,
|
1060 |
+
negative_prompt,
|
1061 |
+
prompt_embeds=prompt_embeds,
|
1062 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1063 |
+
lora_scale=text_encoder_lora_scale,
|
1064 |
+
clip_skip=self.clip_skip,
|
1065 |
+
)
|
1066 |
+
# For classifier free guidance, we need to do two forward passes.
|
1067 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
1068 |
+
# to avoid doing two forward passes
|
1069 |
+
if self.do_classifier_free_guidance:
|
1070 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
1071 |
+
|
1072 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
1073 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
1074 |
+
ip_adapter_image,
|
1075 |
+
ip_adapter_image_embeds,
|
1076 |
+
device,
|
1077 |
+
batch_size * num_images_per_prompt,
|
1078 |
+
self.do_classifier_free_guidance,
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
# 4. Prepare image
|
1082 |
+
if isinstance(brushnet, BrushNetModel):
|
1083 |
+
images = self.prepare_image(
|
1084 |
+
images=images,
|
1085 |
+
width=width,
|
1086 |
+
height=height,
|
1087 |
+
batch_size=batch_size * num_images_per_prompt,
|
1088 |
+
num_images_per_prompt=num_images_per_prompt,
|
1089 |
+
device=device,
|
1090 |
+
dtype=brushnet.dtype,
|
1091 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1092 |
+
guess_mode=guess_mode,
|
1093 |
+
)
|
1094 |
+
original_masks = self.prepare_image(
|
1095 |
+
images=masks,
|
1096 |
+
width=width,
|
1097 |
+
height=height,
|
1098 |
+
batch_size=batch_size * num_images_per_prompt,
|
1099 |
+
num_images_per_prompt=num_images_per_prompt,
|
1100 |
+
device=device,
|
1101 |
+
dtype=brushnet.dtype,
|
1102 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1103 |
+
guess_mode=guess_mode,
|
1104 |
+
)
|
1105 |
+
original_masks_new = []
|
1106 |
+
for original_mask in original_masks:
|
1107 |
+
original_mask=(original_mask.sum(1)[:,None,:,:] < 0).to(images[0].dtype)
|
1108 |
+
original_masks_new.append(original_mask)
|
1109 |
+
original_masks = original_masks_new
|
1110 |
+
|
1111 |
+
height, width = images[0].shape[-2:]
|
1112 |
+
else:
|
1113 |
+
assert False
|
1114 |
+
|
1115 |
+
# 5. Prepare timesteps
|
1116 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
1117 |
+
self._num_timesteps = len(timesteps)
|
1118 |
+
|
1119 |
+
# 6. Prepare latent variables
|
1120 |
+
num_channels_latents = self.unet.config.in_channels
|
1121 |
+
latents, noise = self.prepare_latents(
|
1122 |
+
batch_size * num_images_per_prompt,
|
1123 |
+
num_channels_latents,
|
1124 |
+
num_frames,
|
1125 |
+
height,
|
1126 |
+
width,
|
1127 |
+
prompt_embeds.dtype,
|
1128 |
+
device,
|
1129 |
+
generator,
|
1130 |
+
latents,
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
# 6.1 prepare condition latents
|
1134 |
+
images = torch.cat(images)
|
1135 |
+
images = images.to(dtype=images[0].dtype)
|
1136 |
+
conditioning_latents = []
|
1137 |
+
num=4
|
1138 |
+
for i in range(0, images.shape[0], num):
|
1139 |
+
conditioning_latents.append(self.vae.encode(images[i : i + num]).latent_dist.sample())
|
1140 |
+
conditioning_latents = torch.cat(conditioning_latents, dim=0)
|
1141 |
+
|
1142 |
+
conditioning_latents = conditioning_latents * self.vae.config.scaling_factor #[(f c h w],c2=4
|
1143 |
+
|
1144 |
+
original_masks = torch.cat(original_masks)
|
1145 |
+
masks = torch.nn.functional.interpolate(
|
1146 |
+
original_masks,
|
1147 |
+
size=(
|
1148 |
+
latents.shape[-2],
|
1149 |
+
latents.shape[-1]
|
1150 |
+
)
|
1151 |
+
) ##[ f c h w],c=1
|
1152 |
+
|
1153 |
+
conditioning_latents=torch.concat([conditioning_latents,masks],1)
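# Note (added for clarity, not part of the original code): the BrushNet condition is
# built per frame from two parts, the VAE latents of the masked frames (4 channels,
# scaled by the VAE scaling factor) and the inpainting mask resized to latent
# resolution (1 channel), giving a (num_frames, 5, h, w) tensor that is later indexed
# with the per-clip frame contexts.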
|
1154 |
+
|
1155 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
1156 |
+
timestep_cond = None
|
1157 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
1158 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
1159 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
1160 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
1161 |
+
).to(device=device, dtype=latents.dtype)
|
1162 |
+
|
1163 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1164 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1165 |
+
|
1166 |
+
# 7.1 Add image embeds for IP-Adapter
|
1167 |
+
added_cond_kwargs = (
|
1168 |
+
{"image_embeds": image_embeds}
|
1169 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
1170 |
+
else None
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
# 7.2 Create tensor stating which brushnets to keep
|
1174 |
+
brushnet_keep = []
|
1175 |
+
for i in range(len(timesteps)):
|
1176 |
+
keeps = [
|
1177 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
1178 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
1179 |
+
]
|
1180 |
+
brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps)
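# Note (added for clarity, not part of the original code): `brushnet_keep[i]` is 1.0
# while step i falls inside the [control_guidance_start, control_guidance_end] window
# and 0.0 outside it; it is multiplied into the conditioning scale below. For example,
# with 50 steps, control_guidance_start=0.0 and control_guidance_end=0.5, only the
# first 25 steps receive BrushNet guidance.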
|
1181 |
+
|
1182 |
+
|
1183 |
+
overlap = num_frames//4
|
1184 |
+
context_list, context_list_swap = get_frames_context_swap(video_length, overlap=overlap, num_frames_per_clip=num_frames)
|
1185 |
+
scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list)
|
1186 |
+
scheduler_status_swap = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list_swap)
|
1187 |
+
count = torch.zeros_like(latents)
|
1188 |
+
value = torch.zeros_like(latents)
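# Note (added for clarity, not part of the original code): long videos are denoised
# clip by clip. `context_list` and `context_list_swap` hold overlapping frame-index
# windows (the swap list is offset by half a clip), and the loop below alternates
# between them on even/odd steps. `count` tracks how often each frame is visited
# within the current step, `value` accumulates the blended per-clip latents so that
# overlapping frames are cross-faded rather than overwritten, and each clip keeps its
# own deep-copied scheduler state.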
|
1189 |
+
|
1190 |
+
|
1191 |
+
# 8. Denoising loop
|
1192 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1193 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
1194 |
+
is_brushnet_compiled = is_compiled_module(self.brushnet)
|
1195 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
1196 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1197 |
+
for i, t in enumerate(timesteps):
|
1198 |
+
|
1199 |
+
count.zero_()
|
1200 |
+
value.zero_()
|
1201 |
+
## swap
|
1202 |
+
if (i%2==1):
|
1203 |
+
context_list_choose = context_list_swap
|
1204 |
+
scheduler_status_choose = scheduler_status_swap
|
1205 |
+
else:
|
1206 |
+
context_list_choose = context_list
|
1207 |
+
scheduler_status_choose = scheduler_status
|
1208 |
+
|
1209 |
+
|
1210 |
+
for j, context in enumerate(context_list_choose):
|
1211 |
+
self.scheduler.__dict__.update(scheduler_status_choose[j])
|
1212 |
+
|
1213 |
+
latents_j = latents[context, :, :, :]
|
1214 |
+
|
1215 |
+
# Relevant thread:
|
1216 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
1217 |
+
if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1:
|
1218 |
+
torch._inductor.cudagraph_mark_step_begin()
|
1219 |
+
# expand the latents if we are doing classifier free guidance
|
1220 |
+
latent_model_input = torch.cat([latents_j] * 2) if self.do_classifier_free_guidance else latents_j
|
1221 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1222 |
+
|
1223 |
+
# brushnet(s) inference
|
1224 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1225 |
+
# Infer BrushNet only for the conditional batch.
|
1226 |
+
control_model_input = latents_j
|
1227 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
1228 |
+
brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
1229 |
+
brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
|
1230 |
+
else:
|
1231 |
+
control_model_input = latent_model_input
|
1232 |
+
brushnet_prompt_embeds = prompt_embeds
|
1233 |
+
if self.do_classifier_free_guidance:
|
1234 |
+
neg_brushnet_prompt_embeds, brushnet_prompt_embeds = brushnet_prompt_embeds.chunk(2)
|
1235 |
+
brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
|
1236 |
+
neg_brushnet_prompt_embeds = rearrange(repeat(neg_brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
|
1237 |
+
brushnet_prompt_embeds = torch.cat([neg_brushnet_prompt_embeds, brushnet_prompt_embeds])
|
1238 |
+
else:
|
1239 |
+
brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
|
1240 |
+
|
1241 |
+
if isinstance(brushnet_keep[i], list):
|
1242 |
+
cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])]
|
1243 |
+
else:
|
1244 |
+
brushnet_cond_scale = brushnet_conditioning_scale
|
1245 |
+
if isinstance(brushnet_cond_scale, list):
|
1246 |
+
brushnet_cond_scale = brushnet_cond_scale[0]
|
1247 |
+
cond_scale = brushnet_cond_scale * brushnet_keep[i]
|
1248 |
+
|
1249 |
+
|
1250 |
+
down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet(
|
1251 |
+
control_model_input,
|
1252 |
+
t,
|
1253 |
+
encoder_hidden_states=brushnet_prompt_embeds,
|
1254 |
+
brushnet_cond=torch.cat([conditioning_latents[context, :, :, :]]*2) if self.do_classifier_free_guidance else conditioning_latents[context, :, :, :],
|
1255 |
+
conditioning_scale=cond_scale,
|
1256 |
+
guess_mode=guess_mode,
|
1257 |
+
return_dict=False,
|
1258 |
+
)
|
1259 |
+
|
1260 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1261 |
+
# Inferred BrushNet only for the conditional batch.
|
1262 |
+
# To apply the output of BrushNet to both the unconditional and conditional batches,
|
1263 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
1264 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
1265 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
1266 |
+
up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples]
|
1267 |
+
|
1268 |
+
# predict the noise residual
|
1269 |
+
noise_pred = self.unet(
|
1270 |
+
latent_model_input,
|
1271 |
+
t,
|
1272 |
+
encoder_hidden_states=prompt_embeds,
|
1273 |
+
timestep_cond=timestep_cond,
|
1274 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1275 |
+
down_block_add_samples=down_block_res_samples,
|
1276 |
+
mid_block_add_sample=mid_block_res_sample,
|
1277 |
+
up_block_add_samples=up_block_res_samples,
|
1278 |
+
added_cond_kwargs=added_cond_kwargs,
|
1279 |
+
return_dict=False,
|
1280 |
+
num_frames=num_frames,
|
1281 |
+
)[0]
|
1282 |
+
|
1283 |
+
# perform guidance
|
1284 |
+
if self.do_classifier_free_guidance:
|
1285 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1286 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1287 |
+
|
1288 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1289 |
+
latents_j = self.scheduler.step(noise_pred, t, latents_j, **extra_step_kwargs, return_dict=False)[0]
|
1290 |
+
|
1291 |
+
count[context, ...] += 1
|
1292 |
+
|
1293 |
+
if j==0:
|
1294 |
+
value[context, ...] += latents_j
|
1295 |
+
else:
|
1296 |
+
overlap_index_list = [index for index, value in enumerate(count[context, 0, 0, 0]) if value > 1]
|
1297 |
+
overlap_cur = len(overlap_index_list)
|
1298 |
+
ratio_next = torch.linspace(0, 1, overlap_cur+2)[1:-1]
|
1299 |
+
ratio_pre = 1-ratio_next
|
1300 |
+
for i_overlap in overlap_index_list:
|
1301 |
+
value[context[i_overlap], ...] = value[context[i_overlap], ...]*ratio_pre[i_overlap] + latents_j[i_overlap, ...]*ratio_next[i_overlap]
|
1302 |
+
value[context[i_overlap:num_frames], ...] = latents_j[i_overlap:num_frames, ...]
|
1303 |
+
|
1304 |
+
latents = value.clone()
|
1305 |
+
|
1306 |
+
if callback_on_step_end is not None:
|
1307 |
+
callback_kwargs = {}
|
1308 |
+
for k in callback_on_step_end_tensor_inputs:
|
1309 |
+
callback_kwargs[k] = locals()[k]
|
1310 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1311 |
+
|
1312 |
+
latents = callback_outputs.pop("latents", latents)
|
1313 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1314 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1315 |
+
|
1316 |
+
# call the callback, if provided
|
1317 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1318 |
+
progress_bar.update()
|
1319 |
+
if callback is not None and i % callback_steps == 0:
|
1320 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1321 |
+
callback(step_idx, t, latents)
|
1322 |
+
|
1323 |
+
|
1324 |
+
# If we do sequential model offloading, let's offload unet and brushnet
|
1325 |
+
# manually for max memory savings
|
1326 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
1327 |
+
self.unet.to("cpu")
|
1328 |
+
self.brushnet.to("cpu")
|
1329 |
+
torch.cuda.empty_cache()
|
1330 |
+
|
1331 |
+
if output_type == "latent":
|
1332 |
+
image = latents
|
1333 |
+
has_nsfw_concept = None
|
1334 |
+
return DiffuEraserPipelineOutput(frames=image, nsfw_content_detected=has_nsfw_concept)
|
1335 |
+
|
1336 |
+
video_tensor = self.decode_latents(latents, weight_dtype=prompt_embeds.dtype)
|
1337 |
+
|
1338 |
+
if output_type == "pt":
|
1339 |
+
video = video_tensor
|
1340 |
+
else:
|
1341 |
+
video = self.image_processor.postprocess(video_tensor, output_type=output_type)
|
1342 |
+
|
1343 |
+
# Offload all models
|
1344 |
+
self.maybe_free_model_hooks()
|
1345 |
+
|
1346 |
+
if not return_dict:
|
1347 |
+
return (video, has_nsfw_concept)
|
1348 |
+
|
1349 |
+
return DiffuEraserPipelineOutput(frames=video, latents=latents)
|
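Note on the hunk above: the loop denoises the video in overlapping context windows and merges the per-window results with a linear cross-fade, tracked by the `count` and `value` buffers. The short standalone sketch below reproduces only that blending step so the weighting is easier to follow; the window list, tensor shapes and random latents are illustrative stand-ins, not the pipeline's actual inputs.

import torch

# Two hypothetical context windows over a 10-frame clip; frames 4-5 overlap.
context_list = [list(range(0, 6)), list(range(4, 10))]
value = torch.zeros(10, 4, 8, 8)   # blended latents, one slot per frame
count = torch.zeros(10, 4, 8, 8)   # how many windows have written each frame

for j, context in enumerate(context_list):
    latents_j = torch.randn(len(context), 4, 8, 8)  # stand-in for a denoised window
    count[context, ...] += 1
    if j == 0:
        value[context, ...] += latents_j
    else:
        # positions in this window that an earlier window already wrote
        overlap = [k for k, c in enumerate(count[context, 0, 0, 0]) if c > 1]
        ratio_next = torch.linspace(0, 1, len(overlap) + 2)[1:-1]  # e.g. [1/3, 2/3]
        ratio_pre = 1 - ratio_next
        for k in overlap:
            # linear cross-fade: the earlier window dominates at the start of the
            # overlap, the current window dominates at the end
            value[context[k], ...] = value[context[k], ...] * ratio_pre[k] + latents_j[k, ...] * ratio_next[k]
        # frames unique to the current window are copied over as-is
        value[context[len(overlap):], ...] = latents_j[len(overlap):, ...]

latents = value.clone()

The ramp avoids an abrupt jump in the latents at window boundaries, and alternating between `context_list` and `context_list_swap` on successive steps shifts where those boundaries fall.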
examples/example1/mask.mp4
ADDED
Binary file (716 kB)
examples/example1/video.mp4
ADDED
Binary file (672 kB)
examples/example2/mask.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39849531b31960ee023cd33caf402afd4a4c1402276ba8afa04b7888feb52c3f
+size 1249680
examples/example2/video.mp4
ADDED
Binary file (684 kB)
examples/example3/mask.mp4
ADDED
Binary file (142 kB)
examples/example3/video.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b21c936a305f80ed6707bad621712b24bd1e7a69f82ec7cdd949b18fd1a7fd56
+size 5657081
libs/brushnet_CA.py
ADDED
@@ -0,0 +1,939 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from diffusers.utils import BaseOutput, logging
|
10 |
+
from diffusers.models.attention_processor import (
|
11 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
12 |
+
CROSS_ATTENTION_PROCESSORS,
|
13 |
+
AttentionProcessor,
|
14 |
+
AttnAddedKVProcessor,
|
15 |
+
AttnProcessor,
|
16 |
+
)
|
17 |
+
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
18 |
+
from diffusers.models.modeling_utils import ModelMixin
|
19 |
+
from .unet_2d_blocks import (
|
20 |
+
CrossAttnDownBlock2D,
|
21 |
+
DownBlock2D,
|
22 |
+
UNetMidBlock2D,
|
23 |
+
UNetMidBlock2DCrossAttn,
|
24 |
+
get_down_block,
|
25 |
+
get_mid_block,
|
26 |
+
get_up_block,
|
27 |
+
MidBlock2D
|
28 |
+
)
|
29 |
+
|
30 |
+
# from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
31 |
+
from libs.unet_2d_condition import UNet2DConditionModel
|
32 |
+
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class BrushNetOutput(BaseOutput):
|
39 |
+
"""
|
40 |
+
The output of [`BrushNetModel`].
|
41 |
+
|
42 |
+
Args:
|
43 |
+
up_block_res_samples (`tuple[torch.Tensor]`):
|
44 |
+
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
45 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
46 |
+
used to condition the original UNet's upsampling activations.
|
47 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
48 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
49 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
50 |
+
used to condition the original UNet's downsampling activations.
|
51 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
52 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
53 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
54 |
+
Output can be used to condition the original UNet's middle block activation.
|
55 |
+
"""
|
56 |
+
|
57 |
+
up_block_res_samples: Tuple[torch.Tensor]
|
58 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
59 |
+
mid_block_res_sample: torch.Tensor
|
60 |
+
|
61 |
+
|
62 |
+
class BrushNetModel(ModelMixin, ConfigMixin):
|
63 |
+
"""
|
64 |
+
A BrushNet model.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
in_channels (`int`, defaults to 4):
|
68 |
+
The number of channels in the input sample.
|
69 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
70 |
+
Whether to flip the sin to cos in the time embedding.
|
71 |
+
freq_shift (`int`, defaults to 0):
|
72 |
+
The frequency shift to apply to the time embedding.
|
73 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
74 |
+
The tuple of downsample blocks to use.
|
75 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
76 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
77 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
78 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
79 |
+
The tuple of upsample blocks to use.
|
80 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
81 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
82 |
+
The tuple of output channels for each block.
|
83 |
+
layers_per_block (`int`, defaults to 2):
|
84 |
+
The number of layers per block.
|
85 |
+
downsample_padding (`int`, defaults to 1):
|
86 |
+
The padding to use for the downsampling convolution.
|
87 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
88 |
+
The scale factor to use for the mid block.
|
89 |
+
act_fn (`str`, defaults to "silu"):
|
90 |
+
The activation function to use.
|
91 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
92 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
93 |
+
in post-processing.
|
94 |
+
norm_eps (`float`, defaults to 1e-5):
|
95 |
+
The epsilon to use for the normalization.
|
96 |
+
cross_attention_dim (`int`, defaults to 1280):
|
97 |
+
The dimension of the cross attention features.
|
98 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
99 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
100 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
101 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
102 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
103 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
104 |
+
dimension to `cross_attention_dim`.
|
105 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
106 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
107 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
108 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
109 |
+
The dimension of the attention heads.
|
110 |
+
use_linear_projection (`bool`, defaults to `False`):
|
111 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
112 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
113 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
114 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
115 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
116 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
117 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
118 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
119 |
+
class conditioning with `class_embed_type` equal to `None`.
|
120 |
+
upcast_attention (`bool`, defaults to `False`):
|
121 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
122 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
123 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
124 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
125 |
+
`class_embed_type="projection"`.
|
126 |
+
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
127 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
128 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
129 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
130 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
131 |
+
TODO(Patrick) - unused parameter.
|
132 |
+
addition_embed_type_num_heads (`int`, defaults to 64):
|
133 |
+
The number of heads to use for the `TextTimeEmbedding` layer.
|
134 |
+
"""
|
135 |
+
|
136 |
+
_supports_gradient_checkpointing = True
|
137 |
+
|
138 |
+
@register_to_config
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
in_channels: int = 4,
|
142 |
+
conditioning_channels: int = 5,
|
143 |
+
flip_sin_to_cos: bool = True,
|
144 |
+
freq_shift: int = 0,
|
145 |
+
down_block_types: Tuple[str, ...] = (
|
146 |
+
"CrossAttnDownBlock2D",
|
147 |
+
"CrossAttnDownBlock2D",
|
148 |
+
"CrossAttnDownBlock2D",
|
149 |
+
"DownBlock2D",
|
150 |
+
),
|
151 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
152 |
+
up_block_types: Tuple[str, ...] = (
|
153 |
+
"UpBlock2D",
|
154 |
+
"CrossAttnUpBlock2D",
|
155 |
+
"CrossAttnUpBlock2D",
|
156 |
+
"CrossAttnUpBlock2D",
|
157 |
+
),
|
158 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
159 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
160 |
+
layers_per_block: int = 2,
|
161 |
+
downsample_padding: int = 1,
|
162 |
+
mid_block_scale_factor: float = 1,
|
163 |
+
act_fn: str = "silu",
|
164 |
+
norm_num_groups: Optional[int] = 32,
|
165 |
+
norm_eps: float = 1e-5,
|
166 |
+
cross_attention_dim: int = 1280,
|
167 |
+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
168 |
+
encoder_hid_dim: Optional[int] = None,
|
169 |
+
encoder_hid_dim_type: Optional[str] = None,
|
170 |
+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
171 |
+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
172 |
+
use_linear_projection: bool = False,
|
173 |
+
class_embed_type: Optional[str] = None,
|
174 |
+
addition_embed_type: Optional[str] = None,
|
175 |
+
addition_time_embed_dim: Optional[int] = None,
|
176 |
+
num_class_embeds: Optional[int] = None,
|
177 |
+
upcast_attention: bool = False,
|
178 |
+
resnet_time_scale_shift: str = "default",
|
179 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
180 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
181 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
182 |
+
global_pool_conditions: bool = False,
|
183 |
+
addition_embed_type_num_heads: int = 64,
|
184 |
+
):
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
188 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
189 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
190 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
191 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
192 |
+
# which is why we correct for the naming here.
|
193 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
194 |
+
|
195 |
+
# Check inputs
|
196 |
+
if len(down_block_types) != len(up_block_types):
|
197 |
+
raise ValueError(
|
198 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
199 |
+
)
|
200 |
+
|
201 |
+
if len(block_out_channels) != len(down_block_types):
|
202 |
+
raise ValueError(
|
203 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
204 |
+
)
|
205 |
+
|
206 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
207 |
+
raise ValueError(
|
208 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
209 |
+
)
|
210 |
+
|
211 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
212 |
+
raise ValueError(
|
213 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
214 |
+
)
|
215 |
+
|
216 |
+
if isinstance(transformer_layers_per_block, int):
|
217 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
218 |
+
|
219 |
+
# input
|
220 |
+
conv_in_kernel = 3
|
221 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
222 |
+
self.conv_in_condition = nn.Conv2d(
|
223 |
+
in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
224 |
+
)
|
225 |
+
|
226 |
+
# time
|
227 |
+
time_embed_dim = block_out_channels[0] * 4
|
228 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
229 |
+
timestep_input_dim = block_out_channels[0]
|
230 |
+
self.time_embedding = TimestepEmbedding(
|
231 |
+
timestep_input_dim,
|
232 |
+
time_embed_dim,
|
233 |
+
act_fn=act_fn,
|
234 |
+
)
|
235 |
+
|
236 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
237 |
+
encoder_hid_dim_type = "text_proj"
|
238 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
239 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
240 |
+
|
241 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
242 |
+
raise ValueError(
|
243 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
244 |
+
)
|
245 |
+
|
246 |
+
if encoder_hid_dim_type == "text_proj":
|
247 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
248 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
249 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
250 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
251 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
252 |
+
self.encoder_hid_proj = TextImageProjection(
|
253 |
+
text_embed_dim=encoder_hid_dim,
|
254 |
+
image_embed_dim=cross_attention_dim,
|
255 |
+
cross_attention_dim=cross_attention_dim,
|
256 |
+
)
|
257 |
+
|
258 |
+
elif encoder_hid_dim_type is not None:
|
259 |
+
raise ValueError(
|
260 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
self.encoder_hid_proj = None
|
264 |
+
|
265 |
+
# class embedding
|
266 |
+
if class_embed_type is None and num_class_embeds is not None:
|
267 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
268 |
+
elif class_embed_type == "timestep":
|
269 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
270 |
+
elif class_embed_type == "identity":
|
271 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
272 |
+
elif class_embed_type == "projection":
|
273 |
+
if projection_class_embeddings_input_dim is None:
|
274 |
+
raise ValueError(
|
275 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
276 |
+
)
|
277 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
278 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
279 |
+
# 2. it projects from an arbitrary input dimension.
|
280 |
+
#
|
281 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
282 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
283 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
284 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
285 |
+
else:
|
286 |
+
self.class_embedding = None
|
287 |
+
|
288 |
+
if addition_embed_type == "text":
|
289 |
+
if encoder_hid_dim is not None:
|
290 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
291 |
+
else:
|
292 |
+
text_time_embedding_from_dim = cross_attention_dim
|
293 |
+
|
294 |
+
self.add_embedding = TextTimeEmbedding(
|
295 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
296 |
+
)
|
297 |
+
elif addition_embed_type == "text_image":
|
298 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
299 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
300 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
301 |
+
self.add_embedding = TextImageTimeEmbedding(
|
302 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
303 |
+
)
|
304 |
+
elif addition_embed_type == "text_time":
|
305 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
306 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
307 |
+
|
308 |
+
elif addition_embed_type is not None:
|
309 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
310 |
+
|
311 |
+
self.down_blocks = nn.ModuleList([])
|
312 |
+
self.brushnet_down_blocks = nn.ModuleList([])
|
313 |
+
|
314 |
+
if isinstance(only_cross_attention, bool):
|
315 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
316 |
+
|
317 |
+
if isinstance(attention_head_dim, int):
|
318 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
319 |
+
|
320 |
+
if isinstance(num_attention_heads, int):
|
321 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
322 |
+
|
323 |
+
# down
|
324 |
+
output_channel = block_out_channels[0]
|
325 |
+
|
326 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
327 |
+
brushnet_block = zero_module(brushnet_block)
|
328 |
+
self.brushnet_down_blocks.append(brushnet_block) #零卷积
|
329 |
+
|
330 |
+
for i, down_block_type in enumerate(down_block_types):
|
331 |
+
input_channel = output_channel
|
332 |
+
output_channel = block_out_channels[i]
|
333 |
+
is_final_block = i == len(block_out_channels) - 1
|
334 |
+
|
335 |
+
down_block = get_down_block(
|
336 |
+
down_block_type,
|
337 |
+
num_layers=layers_per_block,
|
338 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
339 |
+
in_channels=input_channel,
|
340 |
+
out_channels=output_channel,
|
341 |
+
temb_channels=time_embed_dim,
|
342 |
+
add_downsample=not is_final_block,
|
343 |
+
resnet_eps=norm_eps,
|
344 |
+
resnet_act_fn=act_fn,
|
345 |
+
resnet_groups=norm_num_groups,
|
346 |
+
cross_attention_dim=cross_attention_dim,
|
347 |
+
num_attention_heads=num_attention_heads[i],
|
348 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
349 |
+
downsample_padding=downsample_padding,
|
350 |
+
use_linear_projection=use_linear_projection,
|
351 |
+
only_cross_attention=only_cross_attention[i],
|
352 |
+
upcast_attention=upcast_attention,
|
353 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
354 |
+
)
|
355 |
+
self.down_blocks.append(down_block)
|
356 |
+
|
357 |
+
for _ in range(layers_per_block):
|
358 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
359 |
+
brushnet_block = zero_module(brushnet_block)
|
360 |
+
self.brushnet_down_blocks.append(brushnet_block) #零卷积
|
361 |
+
|
362 |
+
if not is_final_block:
|
363 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
364 |
+
brushnet_block = zero_module(brushnet_block)
|
365 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
366 |
+
|
367 |
+
# mid
|
368 |
+
mid_block_channel = block_out_channels[-1]
|
369 |
+
|
370 |
+
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
371 |
+
brushnet_block = zero_module(brushnet_block)
|
372 |
+
self.brushnet_mid_block = brushnet_block
|
373 |
+
|
374 |
+
self.mid_block = get_mid_block(
|
375 |
+
mid_block_type,
|
376 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
377 |
+
in_channels=mid_block_channel,
|
378 |
+
temb_channels=time_embed_dim,
|
379 |
+
resnet_eps=norm_eps,
|
380 |
+
resnet_act_fn=act_fn,
|
381 |
+
output_scale_factor=mid_block_scale_factor,
|
382 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
383 |
+
cross_attention_dim=cross_attention_dim,
|
384 |
+
num_attention_heads=num_attention_heads[-1],
|
385 |
+
resnet_groups=norm_num_groups,
|
386 |
+
use_linear_projection=use_linear_projection,
|
387 |
+
upcast_attention=upcast_attention,
|
388 |
+
)
|
389 |
+
|
390 |
+
# count how many layers upsample the images
|
391 |
+
self.num_upsamplers = 0
|
392 |
+
|
393 |
+
# up
|
394 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
395 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
396 |
+
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
397 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
398 |
+
|
399 |
+
output_channel = reversed_block_out_channels[0]
|
400 |
+
|
401 |
+
self.up_blocks = nn.ModuleList([])
|
402 |
+
self.brushnet_up_blocks = nn.ModuleList([])
|
403 |
+
|
404 |
+
for i, up_block_type in enumerate(up_block_types):
|
405 |
+
is_final_block = i == len(block_out_channels) - 1
|
406 |
+
|
407 |
+
prev_output_channel = output_channel
|
408 |
+
output_channel = reversed_block_out_channels[i]
|
409 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
410 |
+
|
411 |
+
# add upsample block for all BUT final layer
|
412 |
+
if not is_final_block:
|
413 |
+
add_upsample = True
|
414 |
+
self.num_upsamplers += 1
|
415 |
+
else:
|
416 |
+
add_upsample = False
|
417 |
+
|
418 |
+
up_block = get_up_block(
|
419 |
+
up_block_type,
|
420 |
+
num_layers=layers_per_block+1,
|
421 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
422 |
+
in_channels=input_channel,
|
423 |
+
out_channels=output_channel,
|
424 |
+
prev_output_channel=prev_output_channel,
|
425 |
+
temb_channels=time_embed_dim,
|
426 |
+
add_upsample=add_upsample,
|
427 |
+
resnet_eps=norm_eps,
|
428 |
+
resnet_act_fn=act_fn,
|
429 |
+
resolution_idx=i,
|
430 |
+
resnet_groups=norm_num_groups,
|
431 |
+
cross_attention_dim=cross_attention_dim,
|
432 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
433 |
+
use_linear_projection=use_linear_projection,
|
434 |
+
only_cross_attention=only_cross_attention[i],
|
435 |
+
upcast_attention=upcast_attention,
|
436 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
437 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
438 |
+
)
|
439 |
+
self.up_blocks.append(up_block)
|
440 |
+
prev_output_channel = output_channel
|
441 |
+
|
442 |
+
for _ in range(layers_per_block+1):
|
443 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
444 |
+
brushnet_block = zero_module(brushnet_block)
|
445 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
446 |
+
|
447 |
+
if not is_final_block:
|
448 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
449 |
+
brushnet_block = zero_module(brushnet_block)
|
450 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
451 |
+
|
452 |
+
|
453 |
+
@classmethod
|
454 |
+
def from_unet(
|
455 |
+
cls,
|
456 |
+
unet: UNet2DConditionModel,
|
457 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
458 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
459 |
+
load_weights_from_unet: bool = True,
|
460 |
+
conditioning_channels: int = 5,
|
461 |
+
):
|
462 |
+
r"""
|
463 |
+
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
464 |
+
|
465 |
+
Parameters:
|
466 |
+
unet (`UNet2DConditionModel`):
|
467 |
+
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
468 |
+
where applicable.
|
469 |
+
"""
|
470 |
+
transformer_layers_per_block = (
|
471 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
472 |
+
)
|
473 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
474 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
475 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
476 |
+
addition_time_embed_dim = (
|
477 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
478 |
+
)
|
479 |
+
|
480 |
+
brushnet = cls(
|
481 |
+
in_channels=unet.config.in_channels,
|
482 |
+
conditioning_channels=conditioning_channels,
|
483 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
484 |
+
freq_shift=unet.config.freq_shift,
|
485 |
+
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
|
486 |
+
down_block_types=[
|
487 |
+
"CrossAttnDownBlock2D",
|
488 |
+
"CrossAttnDownBlock2D",
|
489 |
+
"CrossAttnDownBlock2D",
|
490 |
+
"DownBlock2D",
|
491 |
+
],
|
492 |
+
# mid_block_type='MidBlock2D',
|
493 |
+
mid_block_type="UNetMidBlock2DCrossAttn",
|
494 |
+
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
|
495 |
+
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
496 |
+
only_cross_attention=unet.config.only_cross_attention,
|
497 |
+
block_out_channels=unet.config.block_out_channels,
|
498 |
+
layers_per_block=unet.config.layers_per_block,
|
499 |
+
downsample_padding=unet.config.downsample_padding,
|
500 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
501 |
+
act_fn=unet.config.act_fn,
|
502 |
+
norm_num_groups=unet.config.norm_num_groups,
|
503 |
+
norm_eps=unet.config.norm_eps,
|
504 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
505 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
506 |
+
encoder_hid_dim=encoder_hid_dim,
|
507 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
508 |
+
attention_head_dim=unet.config.attention_head_dim,
|
509 |
+
num_attention_heads=unet.config.num_attention_heads,
|
510 |
+
use_linear_projection=unet.config.use_linear_projection,
|
511 |
+
class_embed_type=unet.config.class_embed_type,
|
512 |
+
addition_embed_type=addition_embed_type,
|
513 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
514 |
+
num_class_embeds=unet.config.num_class_embeds,
|
515 |
+
upcast_attention=unet.config.upcast_attention,
|
516 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
517 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
518 |
+
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
519 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
520 |
+
)
|
521 |
+
|
522 |
+
if load_weights_from_unet:
|
523 |
+
conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight)
|
524 |
+
conv_in_condition_weight[:,:4,...]=unet.conv_in.weight
|
525 |
+
conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight
|
526 |
+
brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight)
|
527 |
+
brushnet.conv_in_condition.bias=unet.conv_in.bias
|
528 |
+
|
529 |
+
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
530 |
+
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
531 |
+
|
532 |
+
if brushnet.class_embedding:
|
533 |
+
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
534 |
+
|
535 |
+
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False)
|
536 |
+
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False)
|
537 |
+
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False)
|
538 |
+
|
539 |
+
return brushnet
|
540 |
+
|
541 |
+
@property
|
542 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
543 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
544 |
+
r"""
|
545 |
+
Returns:
|
546 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
547 |
+
indexed by its weight name.
|
548 |
+
"""
|
549 |
+
# set recursively
|
550 |
+
processors = {}
|
551 |
+
|
552 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
553 |
+
if hasattr(module, "get_processor"):
|
554 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
555 |
+
|
556 |
+
for sub_name, child in module.named_children():
|
557 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
558 |
+
|
559 |
+
return processors
|
560 |
+
|
561 |
+
for name, module in self.named_children():
|
562 |
+
fn_recursive_add_processors(name, module, processors)
|
563 |
+
|
564 |
+
return processors
|
565 |
+
|
566 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
567 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
568 |
+
r"""
|
569 |
+
Sets the attention processor to use to compute attention.
|
570 |
+
|
571 |
+
Parameters:
|
572 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
573 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
574 |
+
for **all** `Attention` layers.
|
575 |
+
|
576 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
577 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
578 |
+
|
579 |
+
"""
|
580 |
+
count = len(self.attn_processors.keys())
|
581 |
+
|
582 |
+
if isinstance(processor, dict) and len(processor) != count:
|
583 |
+
raise ValueError(
|
584 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
585 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
586 |
+
)
|
587 |
+
|
588 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
589 |
+
if hasattr(module, "set_processor"):
|
590 |
+
if not isinstance(processor, dict):
|
591 |
+
module.set_processor(processor)
|
592 |
+
else:
|
593 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
594 |
+
|
595 |
+
for sub_name, child in module.named_children():
|
596 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
597 |
+
|
598 |
+
for name, module in self.named_children():
|
599 |
+
fn_recursive_attn_processor(name, module, processor)
|
600 |
+
|
601 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
602 |
+
def set_default_attn_processor(self):
|
603 |
+
"""
|
604 |
+
Disables custom attention processors and sets the default attention implementation.
|
605 |
+
"""
|
606 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
607 |
+
processor = AttnAddedKVProcessor()
|
608 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
609 |
+
processor = AttnProcessor()
|
610 |
+
else:
|
611 |
+
raise ValueError(
|
612 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
613 |
+
)
|
614 |
+
|
615 |
+
self.set_attn_processor(processor)
|
616 |
+
|
617 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
618 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
619 |
+
r"""
|
620 |
+
Enable sliced attention computation.
|
621 |
+
|
622 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
623 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
624 |
+
|
625 |
+
Args:
|
626 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
627 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
628 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
629 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
630 |
+
must be a multiple of `slice_size`.
|
631 |
+
"""
|
632 |
+
sliceable_head_dims = []
|
633 |
+
|
634 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
635 |
+
if hasattr(module, "set_attention_slice"):
|
636 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
637 |
+
|
638 |
+
for child in module.children():
|
639 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
640 |
+
|
641 |
+
# retrieve number of attention layers
|
642 |
+
for module in self.children():
|
643 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
644 |
+
|
645 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
646 |
+
|
647 |
+
if slice_size == "auto":
|
648 |
+
# half the attention head size is usually a good trade-off between
|
649 |
+
# speed and memory
|
650 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
651 |
+
elif slice_size == "max":
|
652 |
+
# make smallest slice possible
|
653 |
+
slice_size = num_sliceable_layers * [1]
|
654 |
+
|
655 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
656 |
+
|
657 |
+
if len(slice_size) != len(sliceable_head_dims):
|
658 |
+
raise ValueError(
|
659 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
660 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
661 |
+
)
|
662 |
+
|
663 |
+
for i in range(len(slice_size)):
|
664 |
+
size = slice_size[i]
|
665 |
+
dim = sliceable_head_dims[i]
|
666 |
+
if size is not None and size > dim:
|
667 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
668 |
+
|
669 |
+
# Recursively walk through all the children.
|
670 |
+
# Any children which exposes the set_attention_slice method
|
671 |
+
# gets the message
|
672 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
673 |
+
if hasattr(module, "set_attention_slice"):
|
674 |
+
module.set_attention_slice(slice_size.pop())
|
675 |
+
|
676 |
+
for child in module.children():
|
677 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
678 |
+
|
679 |
+
reversed_slice_size = list(reversed(slice_size))
|
680 |
+
for module in self.children():
|
681 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
682 |
+
|
683 |
+
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
684 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
685 |
+
module.gradient_checkpointing = value
|
686 |
+
|
687 |
+
def forward(
|
688 |
+
self,
|
689 |
+
sample: torch.FloatTensor,
|
690 |
+
timestep: Union[torch.Tensor, float, int],
|
691 |
+
encoder_hidden_states: torch.Tensor,
|
692 |
+
brushnet_cond: torch.FloatTensor,
|
693 |
+
conditioning_scale: float = 1.0,
|
694 |
+
class_labels: Optional[torch.Tensor] = None,
|
695 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
696 |
+
attention_mask: Optional[torch.Tensor] = None,
|
697 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
698 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
699 |
+
guess_mode: bool = False,
|
700 |
+
return_dict: bool = True,
|
701 |
+
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
702 |
+
"""
|
703 |
+
The [`BrushNetModel`] forward method.
|
704 |
+
|
705 |
+
Args:
|
706 |
+
sample (`torch.FloatTensor`):
|
707 |
+
The noisy input tensor.
|
708 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
709 |
+
The number of timesteps to denoise an input.
|
710 |
+
encoder_hidden_states (`torch.Tensor`):
|
711 |
+
The encoder hidden states.
|
712 |
+
brushnet_cond (`torch.FloatTensor`):
|
713 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
714 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
715 |
+
The scale factor for BrushNet outputs.
|
716 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
717 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
718 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
719 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
720 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
721 |
+
embeddings.
|
722 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
723 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
724 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
725 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
726 |
+
added_cond_kwargs (`dict`):
|
727 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
728 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
729 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
730 |
+
guess_mode (`bool`, defaults to `False`):
|
731 |
+
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
732 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
733 |
+
return_dict (`bool`, defaults to `True`):
|
734 |
+
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
735 |
+
|
736 |
+
Returns:
|
737 |
+
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
738 |
+
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
739 |
+
returned where the first element is the sample tensor.
|
740 |
+
"""
|
741 |
+
# check channel order
|
742 |
+
channel_order = self.config.brushnet_conditioning_channel_order
|
743 |
+
|
744 |
+
if channel_order == "rgb":
|
745 |
+
# in rgb order by default
|
746 |
+
...
|
747 |
+
elif channel_order == "bgr":
|
748 |
+
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
749 |
+
else:
|
750 |
+
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
751 |
+
|
752 |
+
# prepare attention_mask
|
753 |
+
if attention_mask is not None:
|
754 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
755 |
+
attention_mask = attention_mask.unsqueeze(1)
|
756 |
+
|
757 |
+
# 1. time
|
758 |
+
timesteps = timestep
|
759 |
+
if not torch.is_tensor(timesteps):
|
760 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
761 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
762 |
+
is_mps = sample.device.type == "mps"
|
763 |
+
if isinstance(timestep, float):
|
764 |
+
dtype = torch.float32 if is_mps else torch.float64
|
765 |
+
else:
|
766 |
+
dtype = torch.int32 if is_mps else torch.int64
|
767 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
768 |
+
elif len(timesteps.shape) == 0:
|
769 |
+
timesteps = timesteps[None].to(sample.device)
|
770 |
+
|
771 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
772 |
+
timesteps = timesteps.expand(sample.shape[0])
|
773 |
+
|
774 |
+
t_emb = self.time_proj(timesteps)
|
775 |
+
|
776 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
777 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
778 |
+
# there might be better ways to encapsulate this.
|
779 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
780 |
+
|
781 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
782 |
+
aug_emb = None
|
783 |
+
|
784 |
+
if self.class_embedding is not None:
|
785 |
+
if class_labels is None:
|
786 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
787 |
+
|
788 |
+
if self.config.class_embed_type == "timestep":
|
789 |
+
class_labels = self.time_proj(class_labels)
|
790 |
+
|
791 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
792 |
+
emb = emb + class_emb
|
793 |
+
|
794 |
+
if self.config.addition_embed_type is not None:
|
795 |
+
if self.config.addition_embed_type == "text":
|
796 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
797 |
+
|
798 |
+
elif self.config.addition_embed_type == "text_time":
|
799 |
+
if "text_embeds" not in added_cond_kwargs:
|
800 |
+
raise ValueError(
|
801 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
802 |
+
)
|
803 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
804 |
+
if "time_ids" not in added_cond_kwargs:
|
805 |
+
raise ValueError(
|
806 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
807 |
+
)
|
808 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
809 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
810 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
811 |
+
|
812 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
813 |
+
add_embeds = add_embeds.to(emb.dtype)
|
814 |
+
aug_emb = self.add_embedding(add_embeds)
|
815 |
+
|
816 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
817 |
+
|
818 |
+
# 2. pre-process
|
819 |
+
brushnet_cond=torch.concat([sample,brushnet_cond],1)
|
820 |
+
sample = self.conv_in_condition(brushnet_cond)
|
821 |
+
|
822 |
+
|
823 |
+
# 3. down
|
824 |
+
down_block_res_samples = (sample,)
|
825 |
+
for downsample_block in self.down_blocks:
|
826 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
827 |
+
sample, res_samples = downsample_block(
|
828 |
+
hidden_states=sample,
|
829 |
+
temb=emb,
|
830 |
+
encoder_hidden_states=encoder_hidden_states,
|
831 |
+
attention_mask=attention_mask,
|
832 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
833 |
+
)
|
834 |
+
else:
|
835 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
836 |
+
|
837 |
+
down_block_res_samples += res_samples
|
838 |
+
|
839 |
+
# 4. PaintingNet down blocks
|
840 |
+
brushnet_down_block_res_samples = ()
|
841 |
+
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
842 |
+
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
843 |
+
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
844 |
+
|
845 |
+
# 5. mid
|
846 |
+
if self.mid_block is not None:
|
847 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
848 |
+
sample = self.mid_block(
|
849 |
+
sample,
|
850 |
+
emb,
|
851 |
+
encoder_hidden_states=encoder_hidden_states,
|
852 |
+
attention_mask=attention_mask,
|
853 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
854 |
+
)
|
855 |
+
else:
|
856 |
+
sample = self.mid_block(sample, emb)
|
857 |
+
|
858 |
+
# 6. BrushNet mid blocks
|
859 |
+
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
860 |
+
|
861 |
+
|
862 |
+
# 7. up
|
863 |
+
up_block_res_samples = ()
|
864 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
865 |
+
is_final_block = i == len(self.up_blocks) - 1
|
866 |
+
|
867 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
868 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
869 |
+
|
870 |
+
# if we have not reached the final block and need to forward the
|
871 |
+
# upsample size, we do it here
|
872 |
+
if not is_final_block:
|
873 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
874 |
+
|
875 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
876 |
+
sample, up_res_samples = upsample_block(
|
877 |
+
hidden_states=sample,
|
878 |
+
temb=emb,
|
879 |
+
res_hidden_states_tuple=res_samples,
|
880 |
+
encoder_hidden_states=encoder_hidden_states,
|
881 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
882 |
+
upsample_size=upsample_size,
|
883 |
+
attention_mask=attention_mask,
|
884 |
+
return_res_samples=True
|
885 |
+
)
|
886 |
+
else:
|
887 |
+
sample, up_res_samples = upsample_block(
|
888 |
+
hidden_states=sample,
|
889 |
+
temb=emb,
|
890 |
+
res_hidden_states_tuple=res_samples,
|
891 |
+
upsample_size=upsample_size,
|
892 |
+
return_res_samples=True
|
893 |
+
)
|
894 |
+
|
895 |
+
up_block_res_samples += up_res_samples
|
896 |
+
|
897 |
+
# 8. BrushNet up blocks
|
898 |
+
brushnet_up_block_res_samples = ()
|
899 |
+
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
900 |
+
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
901 |
+
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
902 |
+
|
903 |
+
# 6. scaling
|
904 |
+
if guess_mode and not self.config.global_pool_conditions:
|
905 |
+
scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
|
906 |
+
scales = scales * conditioning_scale
|
907 |
+
|
908 |
+
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
|
909 |
+
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
910 |
+
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
|
911 |
+
else:
|
912 |
+
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
|
913 |
+
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
914 |
+
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
915 |
+
|
916 |
+
|
917 |
+
if self.config.global_pool_conditions:
|
918 |
+
brushnet_down_block_res_samples = [
|
919 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
920 |
+
]
|
921 |
+
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
922 |
+
brushnet_up_block_res_samples = [
|
923 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
924 |
+
]
|
925 |
+
|
926 |
+
if not return_dict:
|
927 |
+
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
928 |
+
|
929 |
+
return BrushNetOutput(
|
930 |
+
down_block_res_samples=brushnet_down_block_res_samples,
|
931 |
+
mid_block_res_sample=brushnet_mid_block_res_sample,
|
932 |
+
up_block_res_samples=brushnet_up_block_res_samples
|
933 |
+
)
|
934 |
+
|
935 |
+
|
936 |
+
def zero_module(module):
|
937 |
+
for p in module.parameters():
|
938 |
+
nn.init.zeros_(p)
|
939 |
+
return module
|
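In the file above, every extra `brushnet_down_blocks` / `brushnet_mid_block` / `brushnet_up_blocks` projection is passed through `zero_module`, and `from_unet` copies the base UNet weights (duplicating `unet.conv_in.weight` across the masked-image channels of `conv_in_condition`). The snippet below is a minimal sketch of why the zero-initialised 1x1 convolutions matter; the feature tensor and channel count are made up for illustration and are not part of the file.

import torch
from torch import nn

def zero_module(module: nn.Module) -> nn.Module:
    # Same idea as the helper at the end of brushnet_CA.py: zero every parameter.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

feature = torch.randn(1, 320, 64, 64)                        # stand-in for a UNet activation
brushnet_block = zero_module(nn.Conv2d(320, 320, kernel_size=1))

residual = brushnet_block(feature)                           # all zeros before training
print(torch.allclose(residual, torch.zeros_like(residual)))  # True

# Because the UNet adds these residuals to its own activations, a freshly
# initialised BrushNet branch leaves the base model untouched and only
# gradually learns to inject the masked-image condition.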
libs/transformer_temporal.py
ADDED
@@ -0,0 +1,375 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from diffusers.utils import BaseOutput
|
22 |
+
from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
23 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
25 |
+
from diffusers.models.resnet import AlphaBlender
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class TransformerTemporalModelOutput(BaseOutput):
|
30 |
+
"""
|
31 |
+
The output of [`TransformerTemporalModel`].
|
32 |
+
|
33 |
+
Args:
|
34 |
+
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
35 |
+
The hidden states output conditioned on `encoder_hidden_states` input.
|
36 |
+
"""
|
37 |
+
|
38 |
+
sample: torch.FloatTensor
|
39 |
+
|
40 |
+
|
41 |
+
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
42 |
+
"""
|
43 |
+
A Transformer model for video-like data.
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
47 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
48 |
+
in_channels (`int`, *optional*):
|
49 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
50 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
51 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
52 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
53 |
+
attention_bias (`bool`, *optional*):
|
54 |
+
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
55 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
56 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
57 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
58 |
+
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
59 |
+
activation functions.
|
60 |
+
norm_elementwise_affine (`bool`, *optional*):
|
61 |
+
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
62 |
+
double_self_attention (`bool`, *optional*):
|
63 |
+
Configure if each `TransformerBlock` should contain two self-attention layers.
|
64 |
+
positional_embeddings: (`str`, *optional*):
|
65 |
+
The type of positional embeddings to apply to the sequence input before passing use.
|
66 |
+
num_positional_embeddings: (`int`, *optional*):
|
67 |
+
The maximum length of the sequence over which to apply positional embeddings.
|
68 |
+
"""
|
69 |
+
|
70 |
+
@register_to_config
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
num_attention_heads: int = 16,
|
74 |
+
attention_head_dim: int = 88,
|
75 |
+
in_channels: Optional[int] = None,
|
76 |
+
out_channels: Optional[int] = None,
|
77 |
+
num_layers: int = 1,
|
78 |
+
dropout: float = 0.0,
|
79 |
+
norm_num_groups: int = 32,
|
80 |
+
cross_attention_dim: Optional[int] = None,
|
81 |
+
attention_bias: bool = False,
|
82 |
+
sample_size: Optional[int] = None,
|
83 |
+
activation_fn: str = "geglu",
|
84 |
+
norm_elementwise_affine: bool = True,
|
85 |
+
double_self_attention: bool = True,
|
86 |
+
positional_embeddings: Optional[str] = None,
|
87 |
+
num_positional_embeddings: Optional[int] = None,
|
88 |
+
):
|
89 |
+
super().__init__()
|
90 |
+
self.num_attention_heads = num_attention_heads
|
91 |
+
self.attention_head_dim = attention_head_dim
|
92 |
+
inner_dim = num_attention_heads * attention_head_dim
|
93 |
+
|
94 |
+
self.in_channels = in_channels
|
95 |
+
|
96 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
97 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
98 |
+
|
99 |
+
# 3. Define transformers blocks
|
100 |
+
self.transformer_blocks = nn.ModuleList(
|
101 |
+
[
|
102 |
+
BasicTransformerBlock(
|
103 |
+
inner_dim,
|
104 |
+
num_attention_heads,
|
105 |
+
attention_head_dim,
|
106 |
+
dropout=dropout,
|
107 |
+
cross_attention_dim=cross_attention_dim,
|
108 |
+
activation_fn=activation_fn,
|
109 |
+
attention_bias=attention_bias,
|
110 |
+
double_self_attention=double_self_attention,
|
111 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
112 |
+
positional_embeddings=positional_embeddings,
|
113 |
+
num_positional_embeddings=num_positional_embeddings,
|
114 |
+
)
|
115 |
+
for d in range(num_layers)
|
116 |
+
]
|
117 |
+
)
|
118 |
+
|
119 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
120 |
+
|
121 |
+
def forward(
|
122 |
+
self,
|
123 |
+
hidden_states: torch.FloatTensor,
|
124 |
+
timestep: Optional[torch.LongTensor] = None,
|
125 |
+
num_frames: int = 1,
|
126 |
+
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
127 |
+
class_labels: torch.LongTensor = None,
|
128 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
129 |
+
) -> TransformerTemporalModelOutput:
|
130 |
+
"""
|
131 |
+
The [`TransformerTemporal`] forward method.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
135 |
+
Input hidden_states.
|
136 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
137 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
138 |
+
self-attention.
|
139 |
+
timestep ( `torch.LongTensor`, *optional*):
|
140 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
141 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
142 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
143 |
+
`AdaLayerZeroNorm`.
|
144 |
+
num_frames (`int`, *optional*, defaults to 1):
|
145 |
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
146 |
+
cross_attention_kwargs (`dict`, *optional*):
|
147 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
148 |
+
`self.processor` in
|
149 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
150 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
151 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
152 |
+
tuple.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
156 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
157 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
158 |
+
"""
|
159 |
+
# 1. Input
|
160 |
+
batch_frames, channel, height, width = hidden_states.shape
|
161 |
+
batch_size = batch_frames // num_frames
|
162 |
+
|
163 |
+
residual = hidden_states
|
164 |
+
|
165 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
166 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
167 |
+
|
168 |
+
hidden_states = self.norm(hidden_states)
|
169 |
+
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
170 |
+
|
171 |
+
hidden_states = self.proj_in(hidden_states)
|
172 |
+
|
173 |
+
# 2. Blocks
|
174 |
+
for block in self.transformer_blocks:
|
175 |
+
hidden_states = block(
|
176 |
+
hidden_states,
|
177 |
+
encoder_hidden_states=encoder_hidden_states,
|
178 |
+
timestep=timestep,
|
179 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
180 |
+
class_labels=class_labels,
|
181 |
+
)
|
182 |
+
|
183 |
+
# 3. Output
|
184 |
+
hidden_states = self.proj_out(hidden_states)
|
185 |
+
hidden_states = (
|
186 |
+
hidden_states[None, None, :]
|
187 |
+
.reshape(batch_size, height, width, num_frames, channel)
|
188 |
+
.permute(0, 3, 4, 1, 2)
|
189 |
+
.contiguous()
|
190 |
+
)
|
191 |
+
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
192 |
+
|
193 |
+
output = hidden_states + residual
|
194 |
+
|
195 |
+
return output
|
196 |
+
|
197 |
+
|
198 |
+
class TransformerSpatioTemporalModel(nn.Module):
|
199 |
+
"""
|
200 |
+
A Transformer model for video-like data.
|
201 |
+
|
202 |
+
Parameters:
|
203 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
204 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
205 |
+
in_channels (`int`, *optional*):
|
206 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
207 |
+
out_channels (`int`, *optional*):
|
208 |
+
The number of channels in the output (specify if the input is **continuous**).
|
209 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
210 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
211 |
+
"""
|
212 |
+
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
num_attention_heads: int = 16,
|
216 |
+
attention_head_dim: int = 88,
|
217 |
+
in_channels: int = 320,
|
218 |
+
out_channels: Optional[int] = None,
|
219 |
+
num_layers: int = 1,
|
220 |
+
cross_attention_dim: Optional[int] = None,
|
221 |
+
):
|
222 |
+
super().__init__()
|
223 |
+
self.num_attention_heads = num_attention_heads
|
224 |
+
self.attention_head_dim = attention_head_dim
|
225 |
+
|
226 |
+
inner_dim = num_attention_heads * attention_head_dim
|
227 |
+
self.inner_dim = inner_dim
|
228 |
+
|
229 |
+
# 2. Define input layers
|
230 |
+
self.in_channels = in_channels
|
231 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
232 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
233 |
+
|
234 |
+
# 3. Define transformers blocks
|
235 |
+
self.transformer_blocks = nn.ModuleList(
|
236 |
+
[
|
237 |
+
BasicTransformerBlock(
|
238 |
+
inner_dim,
|
239 |
+
num_attention_heads,
|
240 |
+
attention_head_dim,
|
241 |
+
cross_attention_dim=cross_attention_dim,
|
242 |
+
)
|
243 |
+
for d in range(num_layers)
|
244 |
+
]
|
245 |
+
)
|
246 |
+
|
247 |
+
time_mix_inner_dim = inner_dim
|
248 |
+
self.temporal_transformer_blocks = nn.ModuleList(
|
249 |
+
[
|
250 |
+
TemporalBasicTransformerBlock(
|
251 |
+
inner_dim,
|
252 |
+
time_mix_inner_dim,
|
253 |
+
num_attention_heads,
|
254 |
+
attention_head_dim,
|
255 |
+
cross_attention_dim=cross_attention_dim,
|
256 |
+
)
|
257 |
+
for _ in range(num_layers)
|
258 |
+
]
|
259 |
+
)
|
260 |
+
|
261 |
+
time_embed_dim = in_channels * 4
|
262 |
+
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
263 |
+
self.time_proj = Timesteps(in_channels, True, 0)
|
264 |
+
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
265 |
+
|
266 |
+
# 4. Define output layers
|
267 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
268 |
+
# TODO: should use out_channels for continuous projections
|
269 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
270 |
+
|
271 |
+
self.gradient_checkpointing = False
|
272 |
+
|
273 |
+
def forward(
|
274 |
+
self,
|
275 |
+
hidden_states: torch.Tensor,
|
276 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
277 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
278 |
+
return_dict: bool = True,
|
279 |
+
):
|
280 |
+
"""
|
281 |
+
Args:
|
282 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
283 |
+
Input hidden_states.
|
284 |
+
num_frames (`int`):
|
285 |
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
286 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
287 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
288 |
+
self-attention.
|
289 |
+
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
290 |
+
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
291 |
+
images, 0 indicates that the input contains video frames.
|
292 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
293 |
+
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
294 |
+
tuple.
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
298 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
299 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
300 |
+
"""
|
301 |
+
# 1. Input
|
302 |
+
batch_frames, _, height, width = hidden_states.shape
|
303 |
+
num_frames = image_only_indicator.shape[-1]
|
304 |
+
batch_size = batch_frames // num_frames
|
305 |
+
|
306 |
+
time_context = encoder_hidden_states
|
307 |
+
time_context_first_timestep = time_context[None, :].reshape(
|
308 |
+
batch_size, num_frames, -1, time_context.shape[-1]
|
309 |
+
)[:, 0]
|
310 |
+
time_context = time_context_first_timestep[None, :].broadcast_to(
|
311 |
+
height * width, batch_size, 1, time_context.shape[-1]
|
312 |
+
)
|
313 |
+
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
314 |
+
|
315 |
+
residual = hidden_states
|
316 |
+
|
317 |
+
hidden_states = self.norm(hidden_states)
|
318 |
+
inner_dim = hidden_states.shape[1]
|
319 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
320 |
+
hidden_states = self.proj_in(hidden_states)
|
321 |
+
|
322 |
+
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
323 |
+
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
324 |
+
num_frames_emb = num_frames_emb.reshape(-1)
|
325 |
+
t_emb = self.time_proj(num_frames_emb)
|
326 |
+
|
327 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
328 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
329 |
+
# there might be better ways to encapsulate this.
|
330 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
331 |
+
|
332 |
+
emb = self.time_pos_embed(t_emb)
|
333 |
+
emb = emb[:, None, :]
|
334 |
+
|
335 |
+
# 2. Blocks
|
336 |
+
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
337 |
+
if self.training and self.gradient_checkpointing:
|
338 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
339 |
+
block,
|
340 |
+
hidden_states,
|
341 |
+
None,
|
342 |
+
encoder_hidden_states,
|
343 |
+
None,
|
344 |
+
use_reentrant=False,
|
345 |
+
)
|
346 |
+
else:
|
347 |
+
hidden_states = block(
|
348 |
+
hidden_states,
|
349 |
+
encoder_hidden_states=encoder_hidden_states,
|
350 |
+
)
|
351 |
+
|
352 |
+
hidden_states_mix = hidden_states
|
353 |
+
hidden_states_mix = hidden_states_mix + emb
|
354 |
+
|
355 |
+
hidden_states_mix = temporal_block(
|
356 |
+
hidden_states_mix,
|
357 |
+
num_frames=num_frames,
|
358 |
+
encoder_hidden_states=time_context,
|
359 |
+
)
|
360 |
+
hidden_states = self.time_mixer(
|
361 |
+
x_spatial=hidden_states,
|
362 |
+
x_temporal=hidden_states_mix,
|
363 |
+
image_only_indicator=image_only_indicator,
|
364 |
+
)
|
365 |
+
|
366 |
+
# 3. Output
|
367 |
+
hidden_states = self.proj_out(hidden_states)
|
368 |
+
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
369 |
+
|
370 |
+
output = hidden_states + residual
|
371 |
+
|
372 |
+
if not return_dict:
|
373 |
+
return (output,)
|
374 |
+
|
375 |
+
return TransformerTemporalModelOutput(sample=output)
|
libs/unet_2d_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
libs/unet_2d_condition.py
ADDED
@@ -0,0 +1,1359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
23 |
+
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
24 |
+
from diffusers.models.activations import get_activation
|
25 |
+
from diffusers.models.attention_processor import (
|
26 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
27 |
+
CROSS_ATTENTION_PROCESSORS,
|
28 |
+
Attention,
|
29 |
+
AttentionProcessor,
|
30 |
+
AttnAddedKVProcessor,
|
31 |
+
AttnProcessor,
|
32 |
+
)
|
33 |
+
from diffusers.models.embeddings import (
|
34 |
+
GaussianFourierProjection,
|
35 |
+
GLIGENTextBoundingboxProjection,
|
36 |
+
ImageHintTimeEmbedding,
|
37 |
+
ImageProjection,
|
38 |
+
ImageTimeEmbedding,
|
39 |
+
TextImageProjection,
|
40 |
+
TextImageTimeEmbedding,
|
41 |
+
TextTimeEmbedding,
|
42 |
+
TimestepEmbedding,
|
43 |
+
Timesteps,
|
44 |
+
)
|
45 |
+
from diffusers.models.modeling_utils import ModelMixin
|
46 |
+
from .unet_2d_blocks import (
|
47 |
+
get_down_block,
|
48 |
+
get_mid_block,
|
49 |
+
get_up_block,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class UNet2DConditionOutput(BaseOutput):
|
58 |
+
"""
|
59 |
+
The output of [`UNet2DConditionModel`].
|
60 |
+
|
61 |
+
Args:
|
62 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
63 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
64 |
+
"""
|
65 |
+
|
66 |
+
sample: torch.FloatTensor = None
|
67 |
+
|
68 |
+
|
69 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
70 |
+
r"""
|
71 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
72 |
+
shaped output.
|
73 |
+
|
74 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
75 |
+
for all models (such as downloading or saving).
|
76 |
+
|
77 |
+
Parameters:
|
78 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
79 |
+
Height and width of input/output sample.
|
80 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
81 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
82 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
83 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
84 |
+
Whether to flip the sin to cos in the time embedding.
|
85 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
86 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
87 |
+
The tuple of downsample blocks to use.
|
88 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
89 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
90 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
91 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
92 |
+
The tuple of upsample blocks to use.
|
93 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
94 |
+
Whether to include self-attention in the basic transformer blocks, see
|
95 |
+
[`~models.attention.BasicTransformerBlock`].
|
96 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
97 |
+
The tuple of output channels for each block.
|
98 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
99 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
100 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
101 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
102 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
103 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
104 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
105 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
106 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
107 |
+
The dimension of the cross attention features.
|
108 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
109 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
110 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
111 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
112 |
+
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
113 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
114 |
+
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
115 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
116 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
117 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
118 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
119 |
+
dimension to `cross_attention_dim`.
|
120 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
121 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
122 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
123 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
124 |
+
num_attention_heads (`int`, *optional*):
|
125 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
126 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
127 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
128 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
129 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
130 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
131 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
132 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
133 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
134 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
135 |
+
Dimension for the timestep embeddings.
|
136 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
137 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
138 |
+
class conditioning with `class_embed_type` equal to `None`.
|
139 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
140 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
141 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
142 |
+
An optional override for the dimension of the projected time embedding.
|
143 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
144 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
145 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
146 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
147 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
148 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
149 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
150 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
|
151 |
+
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
|
152 |
+
*optional*): The dimension of the `class_labels` input when
|
153 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
154 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
155 |
+
embeddings with the class embeddings.
|
156 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
157 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
158 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
159 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
160 |
+
otherwise.
|
161 |
+
"""
|
162 |
+
|
163 |
+
_supports_gradient_checkpointing = True
|
164 |
+
|
165 |
+
@register_to_config
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
sample_size: Optional[int] = None,
|
169 |
+
in_channels: int = 4,
|
170 |
+
out_channels: int = 4,
|
171 |
+
center_input_sample: bool = False,
|
172 |
+
flip_sin_to_cos: bool = True,
|
173 |
+
freq_shift: int = 0,
|
174 |
+
down_block_types: Tuple[str] = (
|
175 |
+
"CrossAttnDownBlock2D",
|
176 |
+
"CrossAttnDownBlock2D",
|
177 |
+
"CrossAttnDownBlock2D",
|
178 |
+
"DownBlock2D",
|
179 |
+
),
|
180 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
181 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
182 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
183 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
184 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
185 |
+
downsample_padding: int = 1,
|
186 |
+
mid_block_scale_factor: float = 1,
|
187 |
+
dropout: float = 0.0,
|
188 |
+
act_fn: str = "silu",
|
189 |
+
norm_num_groups: Optional[int] = 32,
|
190 |
+
norm_eps: float = 1e-5,
|
191 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
192 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
193 |
+
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
194 |
+
encoder_hid_dim: Optional[int] = None,
|
195 |
+
encoder_hid_dim_type: Optional[str] = None,
|
196 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
197 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
198 |
+
dual_cross_attention: bool = False,
|
199 |
+
use_linear_projection: bool = False,
|
200 |
+
class_embed_type: Optional[str] = None,
|
201 |
+
addition_embed_type: Optional[str] = None,
|
202 |
+
addition_time_embed_dim: Optional[int] = None,
|
203 |
+
num_class_embeds: Optional[int] = None,
|
204 |
+
upcast_attention: bool = False,
|
205 |
+
resnet_time_scale_shift: str = "default",
|
206 |
+
resnet_skip_time_act: bool = False,
|
207 |
+
resnet_out_scale_factor: float = 1.0,
|
208 |
+
time_embedding_type: str = "positional",
|
209 |
+
time_embedding_dim: Optional[int] = None,
|
210 |
+
time_embedding_act_fn: Optional[str] = None,
|
211 |
+
timestep_post_act: Optional[str] = None,
|
212 |
+
time_cond_proj_dim: Optional[int] = None,
|
213 |
+
conv_in_kernel: int = 3,
|
214 |
+
conv_out_kernel: int = 3,
|
215 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
216 |
+
attention_type: str = "default",
|
217 |
+
class_embeddings_concat: bool = False,
|
218 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
219 |
+
cross_attention_norm: Optional[str] = None,
|
220 |
+
addition_embed_type_num_heads: int = 64,
|
221 |
+
):
|
222 |
+
super().__init__()
|
223 |
+
|
224 |
+
self.sample_size = sample_size
|
225 |
+
|
226 |
+
if num_attention_heads is not None:
|
227 |
+
raise ValueError(
|
228 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
229 |
+
)
|
230 |
+
|
231 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
232 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
233 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
234 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
235 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
236 |
+
# which is why we correct for the naming here.
|
237 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
238 |
+
|
239 |
+
# Check inputs
|
240 |
+
self._check_config(
|
241 |
+
down_block_types=down_block_types,
|
242 |
+
up_block_types=up_block_types,
|
243 |
+
only_cross_attention=only_cross_attention,
|
244 |
+
block_out_channels=block_out_channels,
|
245 |
+
layers_per_block=layers_per_block,
|
246 |
+
cross_attention_dim=cross_attention_dim,
|
247 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
248 |
+
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
|
249 |
+
attention_head_dim=attention_head_dim,
|
250 |
+
num_attention_heads=num_attention_heads,
|
251 |
+
)
|
252 |
+
|
253 |
+
# input
|
254 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
255 |
+
self.conv_in = nn.Conv2d(
|
256 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
257 |
+
)
|
258 |
+
|
259 |
+
# time
|
260 |
+
time_embed_dim, timestep_input_dim = self._set_time_proj(
|
261 |
+
time_embedding_type,
|
262 |
+
block_out_channels=block_out_channels,
|
263 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
264 |
+
freq_shift=freq_shift,
|
265 |
+
time_embedding_dim=time_embedding_dim,
|
266 |
+
)
|
267 |
+
|
268 |
+
self.time_embedding = TimestepEmbedding(
|
269 |
+
timestep_input_dim,
|
270 |
+
time_embed_dim,
|
271 |
+
act_fn=act_fn,
|
272 |
+
post_act_fn=timestep_post_act,
|
273 |
+
cond_proj_dim=time_cond_proj_dim,
|
274 |
+
)
|
275 |
+
|
276 |
+
self._set_encoder_hid_proj(
|
277 |
+
encoder_hid_dim_type,
|
278 |
+
cross_attention_dim=cross_attention_dim,
|
279 |
+
encoder_hid_dim=encoder_hid_dim,
|
280 |
+
)
|
281 |
+
|
282 |
+
# class embedding
|
283 |
+
self._set_class_embedding(
|
284 |
+
class_embed_type,
|
285 |
+
act_fn=act_fn,
|
286 |
+
num_class_embeds=num_class_embeds,
|
287 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
288 |
+
time_embed_dim=time_embed_dim,
|
289 |
+
timestep_input_dim=timestep_input_dim,
|
290 |
+
)
|
291 |
+
|
292 |
+
self._set_add_embedding(
|
293 |
+
addition_embed_type,
|
294 |
+
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
295 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
296 |
+
cross_attention_dim=cross_attention_dim,
|
297 |
+
encoder_hid_dim=encoder_hid_dim,
|
298 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
299 |
+
freq_shift=freq_shift,
|
300 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
301 |
+
time_embed_dim=time_embed_dim,
|
302 |
+
)
|
303 |
+
|
304 |
+
if time_embedding_act_fn is None:
|
305 |
+
self.time_embed_act = None
|
306 |
+
else:
|
307 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
308 |
+
|
309 |
+
self.down_blocks = nn.ModuleList([])
|
310 |
+
self.up_blocks = nn.ModuleList([])
|
311 |
+
|
312 |
+
if isinstance(only_cross_attention, bool):
|
313 |
+
if mid_block_only_cross_attention is None:
|
314 |
+
mid_block_only_cross_attention = only_cross_attention
|
315 |
+
|
316 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
317 |
+
|
318 |
+
if mid_block_only_cross_attention is None:
|
319 |
+
mid_block_only_cross_attention = False
|
320 |
+
|
321 |
+
if isinstance(num_attention_heads, int):
|
322 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
323 |
+
|
324 |
+
if isinstance(attention_head_dim, int):
|
325 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
326 |
+
|
327 |
+
if isinstance(cross_attention_dim, int):
|
328 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
329 |
+
|
330 |
+
if isinstance(layers_per_block, int):
|
331 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
332 |
+
|
333 |
+
if isinstance(transformer_layers_per_block, int):
|
334 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
335 |
+
|
336 |
+
if class_embeddings_concat:
|
337 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
338 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
339 |
+
# regular time embeddings
|
340 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
341 |
+
else:
|
342 |
+
blocks_time_embed_dim = time_embed_dim
|
343 |
+
|
344 |
+
# down
|
345 |
+
output_channel = block_out_channels[0]
|
346 |
+
for i, down_block_type in enumerate(down_block_types):
|
347 |
+
input_channel = output_channel
|
348 |
+
output_channel = block_out_channels[i]
|
349 |
+
is_final_block = i == len(block_out_channels) - 1
|
350 |
+
|
351 |
+
down_block = get_down_block(
|
352 |
+
down_block_type,
|
353 |
+
num_layers=layers_per_block[i],
|
354 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
355 |
+
in_channels=input_channel,
|
356 |
+
out_channels=output_channel,
|
357 |
+
temb_channels=blocks_time_embed_dim,
|
358 |
+
add_downsample=not is_final_block,
|
359 |
+
resnet_eps=norm_eps,
|
360 |
+
resnet_act_fn=act_fn,
|
361 |
+
resnet_groups=norm_num_groups,
|
362 |
+
cross_attention_dim=cross_attention_dim[i],
|
363 |
+
num_attention_heads=num_attention_heads[i],
|
364 |
+
downsample_padding=downsample_padding,
|
365 |
+
dual_cross_attention=dual_cross_attention,
|
366 |
+
use_linear_projection=use_linear_projection,
|
367 |
+
only_cross_attention=only_cross_attention[i],
|
368 |
+
upcast_attention=upcast_attention,
|
369 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
370 |
+
attention_type=attention_type,
|
371 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
372 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
373 |
+
cross_attention_norm=cross_attention_norm,
|
374 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
375 |
+
dropout=dropout,
|
376 |
+
)
|
377 |
+
self.down_blocks.append(down_block)
|
378 |
+
|
379 |
+
# mid
|
380 |
+
self.mid_block = get_mid_block(
|
381 |
+
mid_block_type,
|
382 |
+
temb_channels=blocks_time_embed_dim,
|
383 |
+
in_channels=block_out_channels[-1],
|
384 |
+
resnet_eps=norm_eps,
|
385 |
+
resnet_act_fn=act_fn,
|
386 |
+
resnet_groups=norm_num_groups,
|
387 |
+
output_scale_factor=mid_block_scale_factor,
|
388 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
389 |
+
num_attention_heads=num_attention_heads[-1],
|
390 |
+
cross_attention_dim=cross_attention_dim[-1],
|
391 |
+
dual_cross_attention=dual_cross_attention,
|
392 |
+
use_linear_projection=use_linear_projection,
|
393 |
+
mid_block_only_cross_attention=mid_block_only_cross_attention,
|
394 |
+
upcast_attention=upcast_attention,
|
395 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
396 |
+
attention_type=attention_type,
|
397 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
398 |
+
cross_attention_norm=cross_attention_norm,
|
399 |
+
attention_head_dim=attention_head_dim[-1],
|
400 |
+
dropout=dropout,
|
401 |
+
)
|
402 |
+
|
403 |
+
# count how many layers upsample the images
|
404 |
+
self.num_upsamplers = 0
|
405 |
+
|
406 |
+
# up
|
407 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
408 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
409 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
410 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
411 |
+
reversed_transformer_layers_per_block = (
|
412 |
+
list(reversed(transformer_layers_per_block))
|
413 |
+
if reverse_transformer_layers_per_block is None
|
414 |
+
else reverse_transformer_layers_per_block
|
415 |
+
)
|
416 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
417 |
+
|
418 |
+
output_channel = reversed_block_out_channels[0]
|
419 |
+
for i, up_block_type in enumerate(up_block_types):
|
420 |
+
is_final_block = i == len(block_out_channels) - 1
|
421 |
+
|
422 |
+
prev_output_channel = output_channel
|
423 |
+
output_channel = reversed_block_out_channels[i]
|
424 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
425 |
+
|
426 |
+
# add upsample block for all BUT final layer
|
427 |
+
if not is_final_block:
|
428 |
+
add_upsample = True
|
429 |
+
self.num_upsamplers += 1
|
430 |
+
else:
|
431 |
+
add_upsample = False
|
432 |
+
|
433 |
+
up_block = get_up_block(
|
434 |
+
up_block_type,
|
435 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
436 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
437 |
+
in_channels=input_channel,
|
438 |
+
out_channels=output_channel,
|
439 |
+
prev_output_channel=prev_output_channel,
|
440 |
+
temb_channels=blocks_time_embed_dim,
|
441 |
+
add_upsample=add_upsample,
|
442 |
+
resnet_eps=norm_eps,
|
443 |
+
resnet_act_fn=act_fn,
|
444 |
+
resolution_idx=i,
|
445 |
+
resnet_groups=norm_num_groups,
|
446 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
447 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
448 |
+
dual_cross_attention=dual_cross_attention,
|
449 |
+
use_linear_projection=use_linear_projection,
|
450 |
+
only_cross_attention=only_cross_attention[i],
|
451 |
+
upcast_attention=upcast_attention,
|
452 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
453 |
+
attention_type=attention_type,
|
454 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
455 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
456 |
+
cross_attention_norm=cross_attention_norm,
|
457 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
458 |
+
dropout=dropout,
|
459 |
+
)
|
460 |
+
self.up_blocks.append(up_block)
|
461 |
+
prev_output_channel = output_channel
|
462 |
+
|
463 |
+
# out
|
464 |
+
if norm_num_groups is not None:
|
465 |
+
self.conv_norm_out = nn.GroupNorm(
|
466 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
467 |
+
)
|
468 |
+
|
469 |
+
self.conv_act = get_activation(act_fn)
|
470 |
+
|
471 |
+
else:
|
472 |
+
self.conv_norm_out = None
|
473 |
+
self.conv_act = None
|
474 |
+
|
475 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
476 |
+
self.conv_out = nn.Conv2d(
|
477 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
478 |
+
)
|
479 |
+
|
480 |
+
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
|
481 |
+
|
482 |
+
def _check_config(
|
483 |
+
self,
|
484 |
+
down_block_types: Tuple[str],
|
485 |
+
up_block_types: Tuple[str],
|
486 |
+
only_cross_attention: Union[bool, Tuple[bool]],
|
487 |
+
block_out_channels: Tuple[int],
|
488 |
+
layers_per_block: Union[int, Tuple[int]],
|
489 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
490 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
491 |
+
reverse_transformer_layers_per_block: bool,
|
492 |
+
attention_head_dim: int,
|
493 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]],
|
494 |
+
):
|
495 |
+
if len(down_block_types) != len(up_block_types):
|
496 |
+
raise ValueError(
|
497 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
498 |
+
)
|
499 |
+
|
500 |
+
if len(block_out_channels) != len(down_block_types):
|
501 |
+
raise ValueError(
|
502 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
503 |
+
)
|
504 |
+
|
505 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
506 |
+
raise ValueError(
|
507 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
508 |
+
)
|
509 |
+
|
510 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
511 |
+
raise ValueError(
|
512 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
513 |
+
)
|
514 |
+
|
515 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
516 |
+
raise ValueError(
|
517 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
518 |
+
)
|
519 |
+
|
520 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
521 |
+
raise ValueError(
|
522 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
523 |
+
)
|
524 |
+
|
525 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
526 |
+
raise ValueError(
|
527 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
528 |
+
)
|
529 |
+
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
530 |
+
for layer_number_per_block in transformer_layers_per_block:
|
531 |
+
if isinstance(layer_number_per_block, list):
|
532 |
+
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
533 |
+
|
534 |
+
def _set_time_proj(
|
535 |
+
self,
|
536 |
+
time_embedding_type: str,
|
537 |
+
block_out_channels: int,
|
538 |
+
flip_sin_to_cos: bool,
|
539 |
+
freq_shift: float,
|
540 |
+
time_embedding_dim: int,
|
541 |
+
) -> Tuple[int, int]:
|
542 |
+
if time_embedding_type == "fourier":
|
543 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
544 |
+
if time_embed_dim % 2 != 0:
|
545 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
546 |
+
self.time_proj = GaussianFourierProjection(
|
547 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
548 |
+
)
|
549 |
+
timestep_input_dim = time_embed_dim
|
550 |
+
elif time_embedding_type == "positional":
|
551 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
552 |
+
|
553 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
554 |
+
timestep_input_dim = block_out_channels[0]
|
555 |
+
else:
|
556 |
+
raise ValueError(
|
557 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
558 |
+
)
|
559 |
+
|
560 |
+
return time_embed_dim, timestep_input_dim
|
561 |
+
|
562 |
+
def _set_encoder_hid_proj(
|
563 |
+
self,
|
564 |
+
encoder_hid_dim_type: Optional[str],
|
565 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
566 |
+
encoder_hid_dim: Optional[int],
|
567 |
+
):
|
568 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
569 |
+
encoder_hid_dim_type = "text_proj"
|
570 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
571 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
572 |
+
|
573 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
574 |
+
raise ValueError(
|
575 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
576 |
+
)
|
577 |
+
|
578 |
+
if encoder_hid_dim_type == "text_proj":
|
579 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
580 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
581 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
582 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
583 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
584 |
+
self.encoder_hid_proj = TextImageProjection(
|
585 |
+
text_embed_dim=encoder_hid_dim,
|
586 |
+
image_embed_dim=cross_attention_dim,
|
587 |
+
cross_attention_dim=cross_attention_dim,
|
588 |
+
)
|
589 |
+
elif encoder_hid_dim_type == "image_proj":
|
590 |
+
# Kandinsky 2.2
|
591 |
+
self.encoder_hid_proj = ImageProjection(
|
592 |
+
image_embed_dim=encoder_hid_dim,
|
593 |
+
cross_attention_dim=cross_attention_dim,
|
594 |
+
)
|
595 |
+
elif encoder_hid_dim_type is not None:
|
596 |
+
raise ValueError(
|
597 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
598 |
+
)
|
599 |
+
else:
|
600 |
+
self.encoder_hid_proj = None
|
601 |
+
|
602 |
+
def _set_class_embedding(
|
603 |
+
self,
|
604 |
+
class_embed_type: Optional[str],
|
605 |
+
act_fn: str,
|
606 |
+
num_class_embeds: Optional[int],
|
607 |
+
projection_class_embeddings_input_dim: Optional[int],
|
608 |
+
time_embed_dim: int,
|
609 |
+
timestep_input_dim: int,
|
610 |
+
):
|
611 |
+
if class_embed_type is None and num_class_embeds is not None:
|
612 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
613 |
+
elif class_embed_type == "timestep":
|
614 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
615 |
+
elif class_embed_type == "identity":
|
616 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
617 |
+
elif class_embed_type == "projection":
|
618 |
+
if projection_class_embeddings_input_dim is None:
|
619 |
+
raise ValueError(
|
620 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
621 |
+
)
|
622 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
623 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
624 |
+
# 2. it projects from an arbitrary input dimension.
|
625 |
+
#
|
626 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
627 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
628 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
629 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
630 |
+
elif class_embed_type == "simple_projection":
|
631 |
+
if projection_class_embeddings_input_dim is None:
|
632 |
+
raise ValueError(
|
633 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
634 |
+
)
|
635 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
636 |
+
else:
|
637 |
+
self.class_embedding = None
|
638 |
+
|
639 |
+
def _set_add_embedding(
|
640 |
+
self,
|
641 |
+
addition_embed_type: str,
|
642 |
+
addition_embed_type_num_heads: int,
|
643 |
+
addition_time_embed_dim: Optional[int],
|
644 |
+
flip_sin_to_cos: bool,
|
645 |
+
freq_shift: float,
|
646 |
+
cross_attention_dim: Optional[int],
|
647 |
+
encoder_hid_dim: Optional[int],
|
648 |
+
projection_class_embeddings_input_dim: Optional[int],
|
649 |
+
time_embed_dim: int,
|
650 |
+
):
|
651 |
+
if addition_embed_type == "text":
|
652 |
+
if encoder_hid_dim is not None:
|
653 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
654 |
+
else:
|
655 |
+
text_time_embedding_from_dim = cross_attention_dim
|
656 |
+
|
657 |
+
self.add_embedding = TextTimeEmbedding(
|
658 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
659 |
+
)
|
660 |
+
elif addition_embed_type == "text_image":
|
661 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
662 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
663 |
+
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
|
664 |
+
self.add_embedding = TextImageTimeEmbedding(
|
665 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
666 |
+
)
|
667 |
+
elif addition_embed_type == "text_time":
|
668 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
669 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
670 |
+
elif addition_embed_type == "image":
|
671 |
+
# Kandinsky 2.2
|
672 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
673 |
+
elif addition_embed_type == "image_hint":
|
674 |
+
# Kandinsky 2.2 ControlNet
|
675 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
676 |
+
elif addition_embed_type is not None:
|
677 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
678 |
+
|
679 |
+
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
|
680 |
+
if attention_type in ["gated", "gated-text-image"]:
|
681 |
+
positive_len = 768
|
682 |
+
if isinstance(cross_attention_dim, int):
|
683 |
+
positive_len = cross_attention_dim
|
684 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
685 |
+
positive_len = cross_attention_dim[0]
|
686 |
+
|
687 |
+
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
688 |
+
self.position_net = GLIGENTextBoundingboxProjection(
|
689 |
+
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
690 |
+
)
|
691 |
+
|
692 |
+
@property
|
693 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
694 |
+
r"""
|
695 |
+
Returns:
|
696 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
697 |
+
indexed by its weight name.
|
698 |
+
"""
|
699 |
+
# set recursively
|
700 |
+
processors = {}
|
701 |
+
|
702 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
703 |
+
if hasattr(module, "get_processor"):
|
704 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
705 |
+
|
706 |
+
for sub_name, child in module.named_children():
|
707 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
708 |
+
|
709 |
+
return processors
|
710 |
+
|
711 |
+
for name, module in self.named_children():
|
712 |
+
fn_recursive_add_processors(name, module, processors)
|
713 |
+
|
714 |
+
return processors
|
715 |
+
|
716 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
717 |
+
r"""
|
718 |
+
Sets the attention processor to use to compute attention.
|
719 |
+
|
720 |
+
Parameters:
|
721 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
722 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
723 |
+
for **all** `Attention` layers.
|
724 |
+
|
725 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
726 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
727 |
+
|
728 |
+
"""
|
729 |
+
count = len(self.attn_processors.keys())
|
730 |
+
|
731 |
+
if isinstance(processor, dict) and len(processor) != count:
|
732 |
+
raise ValueError(
|
733 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
734 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
735 |
+
)
|
736 |
+
|
737 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
738 |
+
if hasattr(module, "set_processor"):
|
739 |
+
if not isinstance(processor, dict):
|
740 |
+
module.set_processor(processor)
|
741 |
+
else:
|
742 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
743 |
+
|
744 |
+
for sub_name, child in module.named_children():
|
745 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
746 |
+
|
747 |
+
for name, module in self.named_children():
|
748 |
+
fn_recursive_attn_processor(name, module, processor)
|
749 |
+
|
750 |
+
def set_default_attn_processor(self):
|
751 |
+
"""
|
752 |
+
Disables custom attention processors and sets the default attention implementation.
|
753 |
+
"""
|
754 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
755 |
+
processor = AttnAddedKVProcessor()
|
756 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
757 |
+
processor = AttnProcessor()
|
758 |
+
else:
|
759 |
+
raise ValueError(
|
760 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
761 |
+
)
|
762 |
+
|
763 |
+
self.set_attn_processor(processor)
|
764 |
+
|
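The `attn_processors` property above walks the module tree and returns a dict keyed by weight name, and `set_attn_processor` pushes either a single processor instance or such a dict back down. A minimal, hedged usage sketch follows (the checkpoint name is only an example; the same calls apply to the subclassed UNet defined in this file):

# Hypothetical sketch: inspect and replace attention processors on a diffusers UNet.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
print(len(unet.attn_processors))          # dict keyed by weight name, e.g. "down_blocks.0...processor"
unet.set_attn_processor(AttnProcessor())  # a single instance is applied to every Attention layer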
765 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
|
766 |
+
r"""
|
767 |
+
Enable sliced attention computation.
|
768 |
+
|
769 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
770 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
771 |
+
|
772 |
+
Args:
|
773 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
774 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
775 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
776 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
777 |
+
must be a multiple of `slice_size`.
|
778 |
+
"""
|
779 |
+
sliceable_head_dims = []
|
780 |
+
|
781 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
782 |
+
if hasattr(module, "set_attention_slice"):
|
783 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
784 |
+
|
785 |
+
for child in module.children():
|
786 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
787 |
+
|
788 |
+
# retrieve number of attention layers
|
789 |
+
for module in self.children():
|
790 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
791 |
+
|
792 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
793 |
+
|
794 |
+
if slice_size == "auto":
|
795 |
+
# half the attention head size is usually a good trade-off between
|
796 |
+
# speed and memory
|
797 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
798 |
+
elif slice_size == "max":
|
799 |
+
# make smallest slice possible
|
800 |
+
slice_size = num_sliceable_layers * [1]
|
801 |
+
|
802 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
803 |
+
|
804 |
+
if len(slice_size) != len(sliceable_head_dims):
|
805 |
+
raise ValueError(
|
806 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
807 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
808 |
+
)
|
809 |
+
|
810 |
+
for i in range(len(slice_size)):
|
811 |
+
size = slice_size[i]
|
812 |
+
dim = sliceable_head_dims[i]
|
813 |
+
if size is not None and size > dim:
|
814 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
815 |
+
|
816 |
+
# Recursively walk through all the children.
|
817 |
+
# Any children which exposes the set_attention_slice method
|
818 |
+
# gets the message
|
819 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
820 |
+
if hasattr(module, "set_attention_slice"):
|
821 |
+
module.set_attention_slice(slice_size.pop())
|
822 |
+
|
823 |
+
for child in module.children():
|
824 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
825 |
+
|
826 |
+
reversed_slice_size = list(reversed(slice_size))
|
827 |
+
for module in self.children():
|
828 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
829 |
+
|
830 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
831 |
+
if hasattr(module, "gradient_checkpointing"):
|
832 |
+
module.gradient_checkpointing = value
|
833 |
+
|
834 |
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
835 |
+
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
836 |
+
|
837 |
+
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
838 |
+
|
839 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
840 |
+
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
841 |
+
|
842 |
+
Args:
|
843 |
+
s1 (`float`):
|
844 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
845 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
846 |
+
s2 (`float`):
|
847 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
848 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
849 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
850 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
851 |
+
"""
|
852 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
853 |
+
setattr(upsample_block, "s1", s1)
|
854 |
+
setattr(upsample_block, "s2", s2)
|
855 |
+
setattr(upsample_block, "b1", b1)
|
856 |
+
setattr(upsample_block, "b2", b2)
|
857 |
+
|
858 |
+
def disable_freeu(self):
|
859 |
+
"""Disables the FreeU mechanism."""
|
860 |
+
freeu_keys = {"s1", "s2", "b1", "b2"}
|
861 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
862 |
+
for k in freeu_keys:
|
863 |
+
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
864 |
+
setattr(upsample_block, k, None)
|
865 |
+
|
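`enable_freeu` simply stores the four scaling factors on every up block, and `disable_freeu` clears them again. A hedged toggle sketch, using SD1.x values quoted in the FreeU repository (the numbers are illustrative, not prescribed by this file):

# Hypothetical FreeU toggle on an already-constructed unet.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # attenuate skip features, amplify backbone features
# ... run inference ...
unet.disable_freeu()                               # resets s1/s2/b1/b2 to None on each up block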
866 |
+
def fuse_qkv_projections(self):
|
867 |
+
"""
|
868 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
869 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
870 |
+
|
871 |
+
<Tip warning={true}>
|
872 |
+
|
873 |
+
This API is 🧪 experimental.
|
874 |
+
|
875 |
+
</Tip>
|
876 |
+
"""
|
877 |
+
self.original_attn_processors = None
|
878 |
+
|
879 |
+
for _, attn_processor in self.attn_processors.items():
|
880 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
881 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
882 |
+
|
883 |
+
self.original_attn_processors = self.attn_processors
|
884 |
+
|
885 |
+
for module in self.modules():
|
886 |
+
if isinstance(module, Attention):
|
887 |
+
module.fuse_projections(fuse=True)
|
888 |
+
|
889 |
+
def unfuse_qkv_projections(self):
|
890 |
+
"""Disables the fused QKV projection if enabled.
|
891 |
+
|
892 |
+
<Tip warning={true}>
|
893 |
+
|
894 |
+
This API is 🧪 experimental.
|
895 |
+
|
896 |
+
</Tip>
|
897 |
+
|
898 |
+
"""
|
899 |
+
if self.original_attn_processors is not None:
|
900 |
+
self.set_attn_processor(self.original_attn_processors)
|
901 |
+
|
902 |
+
def unload_lora(self):
|
903 |
+
"""Unloads LoRA weights."""
|
904 |
+
deprecate(
|
905 |
+
"unload_lora",
|
906 |
+
"0.28.0",
|
907 |
+
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
|
908 |
+
)
|
909 |
+
for module in self.modules():
|
910 |
+
if hasattr(module, "set_lora_layer"):
|
911 |
+
module.set_lora_layer(None)
|
912 |
+
|
913 |
+
def get_time_embed(
|
914 |
+
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
|
915 |
+
) -> Optional[torch.Tensor]:
|
916 |
+
timesteps = timestep
|
917 |
+
if not torch.is_tensor(timesteps):
|
918 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
919 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
920 |
+
is_mps = sample.device.type == "mps"
|
921 |
+
if isinstance(timestep, float):
|
922 |
+
dtype = torch.float32 if is_mps else torch.float64
|
923 |
+
else:
|
924 |
+
dtype = torch.int32 if is_mps else torch.int64
|
925 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
926 |
+
elif len(timesteps.shape) == 0:
|
927 |
+
timesteps = timesteps[None].to(sample.device)
|
928 |
+
|
929 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
930 |
+
timesteps = timesteps.expand(sample.shape[0])
|
931 |
+
|
932 |
+
t_emb = self.time_proj(timesteps)
|
933 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
934 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
935 |
+
# there might be better ways to encapsulate this.
|
936 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
937 |
+
return t_emb
|
938 |
+
|
939 |
+
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
940 |
+
class_emb = None
|
941 |
+
if self.class_embedding is not None:
|
942 |
+
if class_labels is None:
|
943 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
944 |
+
|
945 |
+
if self.config.class_embed_type == "timestep":
|
946 |
+
class_labels = self.time_proj(class_labels)
|
947 |
+
|
948 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
949 |
+
# there might be better ways to encapsulate this.
|
950 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
951 |
+
|
952 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
953 |
+
return class_emb
|
954 |
+
|
955 |
+
def get_aug_embed(
|
956 |
+
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
957 |
+
) -> Optional[torch.Tensor]:
|
958 |
+
aug_emb = None
|
959 |
+
if self.config.addition_embed_type == "text":
|
960 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
961 |
+
elif self.config.addition_embed_type == "text_image":
|
962 |
+
# Kandinsky 2.1 - style
|
963 |
+
if "image_embeds" not in added_cond_kwargs:
|
964 |
+
raise ValueError(
|
965 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
966 |
+
)
|
967 |
+
|
968 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
969 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
970 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
971 |
+
elif self.config.addition_embed_type == "text_time":
|
972 |
+
# SDXL - style
|
973 |
+
if "text_embeds" not in added_cond_kwargs:
|
974 |
+
raise ValueError(
|
975 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
976 |
+
)
|
977 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
978 |
+
if "time_ids" not in added_cond_kwargs:
|
979 |
+
raise ValueError(
|
980 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
981 |
+
)
|
982 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
983 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
984 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
985 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
986 |
+
add_embeds = add_embeds.to(emb.dtype)
|
987 |
+
aug_emb = self.add_embedding(add_embeds)
|
988 |
+
elif self.config.addition_embed_type == "image":
|
989 |
+
# Kandinsky 2.2 - style
|
990 |
+
if "image_embeds" not in added_cond_kwargs:
|
991 |
+
raise ValueError(
|
992 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
993 |
+
)
|
994 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
995 |
+
aug_emb = self.add_embedding(image_embs)
|
996 |
+
elif self.config.addition_embed_type == "image_hint":
|
997 |
+
# Kandinsky 2.2 - style
|
998 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
999 |
+
raise ValueError(
|
1000 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
1001 |
+
)
|
1002 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1003 |
+
hint = added_cond_kwargs.get("hint")
|
1004 |
+
aug_emb = self.add_embedding(image_embs, hint)
|
1005 |
+
return aug_emb
|
1006 |
+
|
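The branches of `get_aug_embed` above are selected purely by `added_cond_kwargs`. As a hedged illustration of the SDXL-style "text_time" branch (batch size and embedding width are placeholders; `time_ids` packs original size, crop coordinates and target size the way the SDXL pipelines do):

# Hypothetical added_cond_kwargs for the "text_time" branch.
import torch
added_cond_kwargs = {
    "text_embeds": torch.randn(2, 1280),  # pooled text embeddings, one row per sample
    "time_ids": torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2),  # (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
}
# forward(..., added_cond_kwargs=added_cond_kwargs) then routes these through
# add_time_proj and add_embedding exactly as in the branch above.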
1007 |
+
def process_encoder_hidden_states(
|
1008 |
+
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
1009 |
+
) -> torch.Tensor:
|
1010 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
1011 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
1012 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
1013 |
+
# Kandinsky 2.1 - style
|
1014 |
+
if "image_embeds" not in added_cond_kwargs:
|
1015 |
+
raise ValueError(
|
1016 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1017 |
+
)
|
1018 |
+
|
1019 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1020 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
1021 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
1022 |
+
# Kandinsky 2.2 - style
|
1023 |
+
if "image_embeds" not in added_cond_kwargs:
|
1024 |
+
raise ValueError(
|
1025 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1026 |
+
)
|
1027 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1028 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
1029 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
1030 |
+
if "image_embeds" not in added_cond_kwargs:
|
1031 |
+
raise ValueError(
|
1032 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1033 |
+
)
|
1034 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1035 |
+
image_embeds = self.encoder_hid_proj(image_embeds)
|
1036 |
+
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
1037 |
+
return encoder_hidden_states
|
1038 |
+
|
1039 |
+
def forward(
|
1040 |
+
self,
|
1041 |
+
sample: torch.FloatTensor,
|
1042 |
+
timestep: Union[torch.Tensor, float, int],
|
1043 |
+
encoder_hidden_states: torch.Tensor,
|
1044 |
+
class_labels: Optional[torch.Tensor] = None,
|
1045 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
1046 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1047 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1048 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
1049 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1050 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1051 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1052 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1053 |
+
return_dict: bool = True,
|
1054 |
+
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
1055 |
+
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
1056 |
+
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
1057 |
+
features_adapter: Optional[torch.Tensor] = None,
|
1058 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
1059 |
+
r"""
|
1060 |
+
The [`UNet2DConditionModel`] forward method.
|
1061 |
+
|
1062 |
+
Args:
|
1063 |
+
sample (`torch.FloatTensor`):
|
1064 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
1065 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
1066 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
1067 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
1068 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
1069 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
1070 |
+
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
1071 |
+
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
1072 |
+
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
1073 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
1074 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
1075 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
1076 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
1077 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1078 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1079 |
+
`self.processor` in
|
1080 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1081 |
+
added_cond_kwargs: (`dict`, *optional*):
|
1082 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
1083 |
+
are passed along to the UNet blocks.
|
1084 |
+
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
1085 |
+
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
1086 |
+
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
1087 |
+
A tensor that if specified is added to the residual of the middle unet block.
|
1088 |
+
encoder_attention_mask (`torch.Tensor`):
|
1089 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
1090 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
1091 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
1092 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1093 |
+
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
1094 |
+
tuple.
|
1095 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1096 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
1097 |
+
added_cond_kwargs: (`dict`, *optional*):
|
1098 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
1099 |
+
are passed along to the UNet blocks.
|
1100 |
+
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
1101 |
+
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
1102 |
+
example from ControlNet side model(s)
|
1103 |
+
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
1104 |
+
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
1105 |
+
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
1106 |
+
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
1107 |
+
features_adapter (`torch.FloatTensor`, *optional*):
|
1108 |
+
(batch, channels, num_frames, height, width) adapter features tensor
|
1109 |
+
|
1110 |
+
Returns:
|
1111 |
+
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
1112 |
+
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
1113 |
+
a `tuple` is returned where the first element is the sample tensor.
|
1114 |
+
"""
|
1115 |
+
# By default samples have to be at least a multiple of the overall upsampling factor.
|
1116 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
1117 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1118 |
+
# on the fly if necessary.
|
1119 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
1120 |
+
|
1121 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1122 |
+
forward_upsample_size = False
|
1123 |
+
upsample_size = None
|
1124 |
+
|
1125 |
+
for dim in sample.shape[-2:]:
|
1126 |
+
if dim % default_overall_up_factor != 0:
|
1127 |
+
# Forward upsample size to force interpolation output size.
|
1128 |
+
forward_upsample_size = True
|
1129 |
+
break
|
1130 |
+
|
1131 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
1132 |
+
# expects mask of shape:
|
1133 |
+
# [batch, key_tokens]
|
1134 |
+
# adds singleton query_tokens dimension:
|
1135 |
+
# [batch, 1, key_tokens]
|
1136 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
1137 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
1138 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
1139 |
+
if attention_mask is not None:
|
1140 |
+
# assume that mask is expressed as:
|
1141 |
+
# (1 = keep, 0 = discard)
|
1142 |
+
# convert mask into a bias that can be added to attention scores:
|
1143 |
+
# (keep = +0, discard = -10000.0)
|
1144 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
1145 |
+
attention_mask = attention_mask.unsqueeze(1)
|
1146 |
+
|
1147 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
1148 |
+
if encoder_attention_mask is not None:
|
1149 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
1150 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
1151 |
+
|
1152 |
+
# 0. center input if necessary
|
1153 |
+
if self.config.center_input_sample:
|
1154 |
+
sample = 2 * sample - 1.0
|
1155 |
+
|
1156 |
+
# 1. time
|
1157 |
+
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
1158 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
1159 |
+
aug_emb = None
|
1160 |
+
|
1161 |
+
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
1162 |
+
if class_emb is not None:
|
1163 |
+
if self.config.class_embeddings_concat:
|
1164 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
1165 |
+
else:
|
1166 |
+
emb = emb + class_emb
|
1167 |
+
|
1168 |
+
aug_emb = self.get_aug_embed(
|
1169 |
+
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1170 |
+
)
|
1171 |
+
if self.config.addition_embed_type == "image_hint":
|
1172 |
+
aug_emb, hint = aug_emb
|
1173 |
+
sample = torch.cat([sample, hint], dim=1)
|
1174 |
+
|
1175 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
1176 |
+
|
1177 |
+
if self.time_embed_act is not None:
|
1178 |
+
emb = self.time_embed_act(emb)
|
1179 |
+
|
1180 |
+
encoder_hidden_states = self.process_encoder_hidden_states(
|
1181 |
+
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
# 2. pre-process
|
1185 |
+
sample = self.conv_in(sample)
|
1186 |
+
|
1187 |
+
# 2.5 GLIGEN position net
|
1188 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
1189 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
1190 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
1191 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
1192 |
+
|
1193 |
+
# 3. down
|
1194 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
1195 |
+
if USE_PEFT_BACKEND:
|
1196 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
1197 |
+
scale_lora_layers(self, lora_scale)
|
1198 |
+
|
1199 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
1200 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
1201 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
1202 |
+
# maintain backward compatibility for legacy usage, where
|
1203 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
1204 |
+
# but can only use one or the other
|
1205 |
+
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
|
1206 |
+
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
1207 |
+
deprecate(
|
1208 |
+
"T2I should not use down_block_additional_residuals",
|
1209 |
+
"1.3.0",
|
1210 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
1211 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
1212 |
+
for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
|
1213 |
+
standard_warn=False,
|
1214 |
+
)
|
1215 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
1216 |
+
is_adapter = True
|
1217 |
+
|
1218 |
+
down_block_res_samples = (sample,)
|
1219 |
+
|
1220 |
+
if is_brushnet:
|
1221 |
+
sample = sample + down_block_add_samples.pop(0)
|
1222 |
+
|
1223 |
+
adapter_idx = 0
|
1224 |
+
for downsample_block in self.down_blocks:
|
1225 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
1226 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
1227 |
+
additional_residuals = {}
|
1228 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1229 |
+
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
1230 |
+
|
1231 |
+
if is_brushnet and len(down_block_add_samples)>0:
|
1232 |
+
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
1233 |
+
for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
|
1234 |
+
|
1235 |
+
sample, res_samples = downsample_block(
|
1236 |
+
hidden_states=sample,
|
1237 |
+
temb=emb,
|
1238 |
+
encoder_hidden_states=encoder_hidden_states,
|
1239 |
+
attention_mask=attention_mask,
|
1240 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1241 |
+
encoder_attention_mask=encoder_attention_mask,
|
1242 |
+
**additional_residuals,
|
1243 |
+
)
|
1244 |
+
else:
|
1245 |
+
additional_residuals = {}
|
1246 |
+
if is_brushnet and len(down_block_add_samples)>0:
|
1247 |
+
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
1248 |
+
for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
|
1249 |
+
|
1250 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals)
|
1251 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1252 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1253 |
+
|
1254 |
+
if features_adapter is not None:
|
1255 |
+
sample += features_adapter[adapter_idx]
|
1256 |
+
adapter_idx += 1
|
1257 |
+
|
1258 |
+
down_block_res_samples += res_samples
|
1259 |
+
|
1260 |
+
if features_adapter is not None:
|
1261 |
+
assert len(features_adapter) == adapter_idx, "Wrong features_adapter"
|
1262 |
+
|
1263 |
+
if is_controlnet:
|
1264 |
+
new_down_block_res_samples = ()
|
1265 |
+
|
1266 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
1267 |
+
down_block_res_samples, down_block_additional_residuals
|
1268 |
+
):
|
1269 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
1270 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
1271 |
+
|
1272 |
+
down_block_res_samples = new_down_block_res_samples
|
1273 |
+
|
1274 |
+
# 4. mid
|
1275 |
+
if self.mid_block is not None:
|
1276 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
1277 |
+
sample = self.mid_block(
|
1278 |
+
sample,
|
1279 |
+
emb,
|
1280 |
+
encoder_hidden_states=encoder_hidden_states,
|
1281 |
+
attention_mask=attention_mask,
|
1282 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1283 |
+
encoder_attention_mask=encoder_attention_mask,
|
1284 |
+
)
|
1285 |
+
else:
|
1286 |
+
sample = self.mid_block(sample, emb)
|
1287 |
+
|
1288 |
+
# To support T2I-Adapter-XL
|
1289 |
+
if (
|
1290 |
+
is_adapter
|
1291 |
+
and len(down_intrablock_additional_residuals) > 0
|
1292 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
1293 |
+
):
|
1294 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1295 |
+
|
1296 |
+
if is_controlnet:
|
1297 |
+
sample = sample + mid_block_additional_residual
|
1298 |
+
|
1299 |
+
if is_brushnet:
|
1300 |
+
sample = sample + mid_block_add_sample
|
1301 |
+
|
1302 |
+
# 5. up
|
1303 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1304 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1305 |
+
|
1306 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1307 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
1308 |
+
|
1309 |
+
# if we have not reached the final block and need to forward the
|
1310 |
+
# upsample size, we do it here
|
1311 |
+
if not is_final_block and forward_upsample_size:
|
1312 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1313 |
+
|
1314 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
1315 |
+
additional_residuals = {}
|
1316 |
+
if is_brushnet and len(up_block_add_samples)>0:
|
1317 |
+
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
1318 |
+
for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))]
|
1319 |
+
|
1320 |
+
sample = upsample_block(
|
1321 |
+
hidden_states=sample,
|
1322 |
+
temb=emb,
|
1323 |
+
res_hidden_states_tuple=res_samples,
|
1324 |
+
encoder_hidden_states=encoder_hidden_states,
|
1325 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1326 |
+
upsample_size=upsample_size,
|
1327 |
+
attention_mask=attention_mask,
|
1328 |
+
encoder_attention_mask=encoder_attention_mask,
|
1329 |
+
**additional_residuals,
|
1330 |
+
)
|
1331 |
+
else:
|
1332 |
+
additional_residuals = {}
|
1333 |
+
if is_brushnet and len(up_block_add_samples)>0:
|
1334 |
+
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
1335 |
+
for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))]
|
1336 |
+
|
1337 |
+
sample = upsample_block(
|
1338 |
+
hidden_states=sample,
|
1339 |
+
temb=emb,
|
1340 |
+
res_hidden_states_tuple=res_samples,
|
1341 |
+
upsample_size=upsample_size,
|
1342 |
+
scale=lora_scale,
|
1343 |
+
**additional_residuals,
|
1344 |
+
)
|
1345 |
+
|
1346 |
+
# 6. post-process
|
1347 |
+
if self.conv_norm_out:
|
1348 |
+
sample = self.conv_norm_out(sample)
|
1349 |
+
sample = self.conv_act(sample)
|
1350 |
+
sample = self.conv_out(sample)
|
1351 |
+
|
1352 |
+
if USE_PEFT_BACKEND:
|
1353 |
+
# remove `lora_scale` from each PEFT layer
|
1354 |
+
unscale_lora_layers(self, lora_scale)
|
1355 |
+
|
1356 |
+
if not return_dict:
|
1357 |
+
return (sample,)
|
1358 |
+
|
1359 |
+
return UNet2DConditionOutput(sample=sample)
|
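The forward method above is the stock UNet2DConditionModel forward extended with the BrushNet-style `down_block_add_samples`, `mid_block_add_sample` and `up_block_add_samples` residual lists. A hedged single-step sketch (shapes are illustrative, and the three residual collections are placeholders for the outputs of the BrushNet branch defined in libs/brushnet_CA.py):

# Hypothetical single denoising step through the modified UNet.
import torch
latents = torch.randn(1, 4, 64, 64)    # noisy latent sample
timestep = torch.tensor([10])
text_emb = torch.randn(1, 77, 768)     # text encoder hidden states
noise_pred = unet(
    latents,
    timestep,
    encoder_hidden_states=text_emb,
    down_block_add_samples=down_feats,  # list popped inside each down block
    mid_block_add_sample=mid_feat,      # added after the mid block
    up_block_add_samples=up_feats,      # list popped inside each up block
    return_dict=False,
)[0]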
libs/unet_3d_blocks.py
ADDED
@@ -0,0 +1,2463 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.utils import is_torch_version
|
21 |
+
from diffusers.utils.torch_utils import apply_freeu
|
22 |
+
from diffusers.models.attention import Attention
|
23 |
+
from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
|
24 |
+
from diffusers.models.resnet import (
|
25 |
+
Downsample2D,
|
26 |
+
ResnetBlock2D,
|
27 |
+
SpatioTemporalResBlock,
|
28 |
+
TemporalConvLayer,
|
29 |
+
Upsample2D,
|
30 |
+
)
|
31 |
+
from diffusers.models.transformers.transformer_2d import Transformer2DModel
|
32 |
+
from diffusers.models.transformers.transformer_temporal import (
|
33 |
+
TransformerSpatioTemporalModel,
|
34 |
+
)
|
35 |
+
from libs.transformer_temporal import TransformerTemporalModel
|
36 |
+
|
37 |
+
def get_down_block(
|
38 |
+
down_block_type: str,
|
39 |
+
num_layers: int,
|
40 |
+
in_channels: int,
|
41 |
+
out_channels: int,
|
42 |
+
temb_channels: int,
|
43 |
+
add_downsample: bool,
|
44 |
+
resnet_eps: float,
|
45 |
+
resnet_act_fn: str,
|
46 |
+
num_attention_heads: int,
|
47 |
+
resnet_groups: Optional[int] = None,
|
48 |
+
cross_attention_dim: Optional[int] = None,
|
49 |
+
downsample_padding: Optional[int] = None,
|
50 |
+
dual_cross_attention: bool = False,
|
51 |
+
use_linear_projection: bool = True,
|
52 |
+
only_cross_attention: bool = False,
|
53 |
+
upcast_attention: bool = False,
|
54 |
+
resnet_time_scale_shift: str = "default",
|
55 |
+
temporal_num_attention_heads: int = 8,
|
56 |
+
temporal_max_seq_length: int = 32,
|
57 |
+
transformer_layers_per_block: int = 1,
|
58 |
+
) -> Union[
|
59 |
+
"DownBlock3D",
|
60 |
+
"CrossAttnDownBlock3D",
|
61 |
+
"DownBlockMotion",
|
62 |
+
"CrossAttnDownBlockMotion",
|
63 |
+
"DownBlockSpatioTemporal",
|
64 |
+
"CrossAttnDownBlockSpatioTemporal",
|
65 |
+
]:
|
66 |
+
if down_block_type == "DownBlock3D":
|
67 |
+
return DownBlock3D(
|
68 |
+
num_layers=num_layers,
|
69 |
+
in_channels=in_channels,
|
70 |
+
out_channels=out_channels,
|
71 |
+
temb_channels=temb_channels,
|
72 |
+
add_downsample=add_downsample,
|
73 |
+
resnet_eps=resnet_eps,
|
74 |
+
resnet_act_fn=resnet_act_fn,
|
75 |
+
resnet_groups=resnet_groups,
|
76 |
+
downsample_padding=downsample_padding,
|
77 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
78 |
+
)
|
79 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
80 |
+
if cross_attention_dim is None:
|
81 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
82 |
+
return CrossAttnDownBlock3D(
|
83 |
+
num_layers=num_layers,
|
84 |
+
in_channels=in_channels,
|
85 |
+
out_channels=out_channels,
|
86 |
+
temb_channels=temb_channels,
|
87 |
+
add_downsample=add_downsample,
|
88 |
+
resnet_eps=resnet_eps,
|
89 |
+
resnet_act_fn=resnet_act_fn,
|
90 |
+
resnet_groups=resnet_groups,
|
91 |
+
downsample_padding=downsample_padding,
|
92 |
+
cross_attention_dim=cross_attention_dim,
|
93 |
+
num_attention_heads=num_attention_heads,
|
94 |
+
dual_cross_attention=dual_cross_attention,
|
95 |
+
use_linear_projection=use_linear_projection,
|
96 |
+
only_cross_attention=only_cross_attention,
|
97 |
+
upcast_attention=upcast_attention,
|
98 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
99 |
+
)
|
100 |
+
if down_block_type == "DownBlockMotion":
|
101 |
+
return DownBlockMotion(
|
102 |
+
num_layers=num_layers,
|
103 |
+
in_channels=in_channels,
|
104 |
+
out_channels=out_channels,
|
105 |
+
temb_channels=temb_channels,
|
106 |
+
add_downsample=add_downsample,
|
107 |
+
resnet_eps=resnet_eps,
|
108 |
+
resnet_act_fn=resnet_act_fn,
|
109 |
+
resnet_groups=resnet_groups,
|
110 |
+
downsample_padding=downsample_padding,
|
111 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
112 |
+
temporal_num_attention_heads=temporal_num_attention_heads,
|
113 |
+
temporal_max_seq_length=temporal_max_seq_length,
|
114 |
+
)
|
115 |
+
elif down_block_type == "CrossAttnDownBlockMotion":
|
116 |
+
if cross_attention_dim is None:
|
117 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
|
118 |
+
return CrossAttnDownBlockMotion(
|
119 |
+
num_layers=num_layers,
|
120 |
+
in_channels=in_channels,
|
121 |
+
out_channels=out_channels,
|
122 |
+
temb_channels=temb_channels,
|
123 |
+
add_downsample=add_downsample,
|
124 |
+
resnet_eps=resnet_eps,
|
125 |
+
resnet_act_fn=resnet_act_fn,
|
126 |
+
resnet_groups=resnet_groups,
|
127 |
+
downsample_padding=downsample_padding,
|
128 |
+
cross_attention_dim=cross_attention_dim,
|
129 |
+
num_attention_heads=num_attention_heads,
|
130 |
+
dual_cross_attention=dual_cross_attention,
|
131 |
+
use_linear_projection=use_linear_projection,
|
132 |
+
only_cross_attention=only_cross_attention,
|
133 |
+
upcast_attention=upcast_attention,
|
134 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
135 |
+
temporal_num_attention_heads=temporal_num_attention_heads,
|
136 |
+
temporal_max_seq_length=temporal_max_seq_length,
|
137 |
+
)
|
138 |
+
elif down_block_type == "DownBlockSpatioTemporal":
|
139 |
+
# added for SDV
|
140 |
+
return DownBlockSpatioTemporal(
|
141 |
+
num_layers=num_layers,
|
142 |
+
in_channels=in_channels,
|
143 |
+
out_channels=out_channels,
|
144 |
+
temb_channels=temb_channels,
|
145 |
+
add_downsample=add_downsample,
|
146 |
+
)
|
147 |
+
elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
|
148 |
+
# added for SDV
|
149 |
+
if cross_attention_dim is None:
|
150 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
|
151 |
+
return CrossAttnDownBlockSpatioTemporal(
|
152 |
+
in_channels=in_channels,
|
153 |
+
out_channels=out_channels,
|
154 |
+
temb_channels=temb_channels,
|
155 |
+
num_layers=num_layers,
|
156 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
157 |
+
add_downsample=add_downsample,
|
158 |
+
cross_attention_dim=cross_attention_dim,
|
159 |
+
num_attention_heads=num_attention_heads,
|
160 |
+
)
|
161 |
+
|
162 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
163 |
+
|
164 |
+
|
165 |
+
def get_up_block(
|
166 |
+
up_block_type: str,
|
167 |
+
num_layers: int,
|
168 |
+
in_channels: int,
|
169 |
+
out_channels: int,
|
170 |
+
prev_output_channel: int,
|
171 |
+
temb_channels: int,
|
172 |
+
add_upsample: bool,
|
173 |
+
resnet_eps: float,
|
174 |
+
resnet_act_fn: str,
|
175 |
+
num_attention_heads: int,
|
176 |
+
resolution_idx: Optional[int] = None,
|
177 |
+
resnet_groups: Optional[int] = None,
|
178 |
+
cross_attention_dim: Optional[int] = None,
|
179 |
+
dual_cross_attention: bool = False,
|
180 |
+
use_linear_projection: bool = True,
|
181 |
+
only_cross_attention: bool = False,
|
182 |
+
upcast_attention: bool = False,
|
183 |
+
resnet_time_scale_shift: str = "default",
|
184 |
+
temporal_num_attention_heads: int = 8,
|
185 |
+
temporal_cross_attention_dim: Optional[int] = None,
|
186 |
+
temporal_max_seq_length: int = 32,
|
187 |
+
transformer_layers_per_block: int = 1,
|
188 |
+
dropout: float = 0.0,
|
189 |
+
) -> Union[
|
190 |
+
"UpBlock3D",
|
191 |
+
"CrossAttnUpBlock3D",
|
192 |
+
"UpBlockMotion",
|
193 |
+
"CrossAttnUpBlockMotion",
|
194 |
+
"UpBlockSpatioTemporal",
|
195 |
+
"CrossAttnUpBlockSpatioTemporal",
|
196 |
+
]:
|
197 |
+
if up_block_type == "UpBlock3D":
|
198 |
+
return UpBlock3D(
|
199 |
+
num_layers=num_layers,
|
200 |
+
in_channels=in_channels,
|
201 |
+
out_channels=out_channels,
|
202 |
+
prev_output_channel=prev_output_channel,
|
203 |
+
temb_channels=temb_channels,
|
204 |
+
add_upsample=add_upsample,
|
205 |
+
resnet_eps=resnet_eps,
|
206 |
+
resnet_act_fn=resnet_act_fn,
|
207 |
+
resnet_groups=resnet_groups,
|
208 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
209 |
+
resolution_idx=resolution_idx,
|
210 |
+
)
|
211 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
212 |
+
if cross_attention_dim is None:
|
213 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
214 |
+
return CrossAttnUpBlock3D(
|
215 |
+
num_layers=num_layers,
|
216 |
+
in_channels=in_channels,
|
217 |
+
out_channels=out_channels,
|
218 |
+
prev_output_channel=prev_output_channel,
|
219 |
+
temb_channels=temb_channels,
|
220 |
+
add_upsample=add_upsample,
|
221 |
+
resnet_eps=resnet_eps,
|
222 |
+
resnet_act_fn=resnet_act_fn,
|
223 |
+
resnet_groups=resnet_groups,
|
224 |
+
cross_attention_dim=cross_attention_dim,
|
225 |
+
num_attention_heads=num_attention_heads,
|
226 |
+
dual_cross_attention=dual_cross_attention,
|
227 |
+
use_linear_projection=use_linear_projection,
|
228 |
+
only_cross_attention=only_cross_attention,
|
229 |
+
upcast_attention=upcast_attention,
|
230 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
231 |
+
resolution_idx=resolution_idx,
|
232 |
+
)
|
233 |
+
if up_block_type == "UpBlockMotion":
|
234 |
+
return UpBlockMotion(
|
235 |
+
num_layers=num_layers,
|
236 |
+
in_channels=in_channels,
|
237 |
+
out_channels=out_channels,
|
238 |
+
prev_output_channel=prev_output_channel,
|
239 |
+
temb_channels=temb_channels,
|
240 |
+
add_upsample=add_upsample,
|
241 |
+
resnet_eps=resnet_eps,
|
242 |
+
resnet_act_fn=resnet_act_fn,
|
243 |
+
resnet_groups=resnet_groups,
|
244 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
245 |
+
resolution_idx=resolution_idx,
|
246 |
+
temporal_num_attention_heads=temporal_num_attention_heads,
|
247 |
+
temporal_max_seq_length=temporal_max_seq_length,
|
248 |
+
)
|
249 |
+
elif up_block_type == "CrossAttnUpBlockMotion":
|
250 |
+
if cross_attention_dim is None:
|
251 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
|
252 |
+
return CrossAttnUpBlockMotion(
|
253 |
+
num_layers=num_layers,
|
254 |
+
in_channels=in_channels,
|
255 |
+
out_channels=out_channels,
|
256 |
+
prev_output_channel=prev_output_channel,
|
257 |
+
temb_channels=temb_channels,
|
258 |
+
add_upsample=add_upsample,
|
259 |
+
resnet_eps=resnet_eps,
|
260 |
+
resnet_act_fn=resnet_act_fn,
|
261 |
+
resnet_groups=resnet_groups,
|
262 |
+
cross_attention_dim=cross_attention_dim,
|
263 |
+
num_attention_heads=num_attention_heads,
|
264 |
+
dual_cross_attention=dual_cross_attention,
|
265 |
+
use_linear_projection=use_linear_projection,
|
266 |
+
only_cross_attention=only_cross_attention,
|
267 |
+
upcast_attention=upcast_attention,
|
268 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
269 |
+
resolution_idx=resolution_idx,
|
270 |
+
temporal_num_attention_heads=temporal_num_attention_heads,
|
271 |
+
temporal_max_seq_length=temporal_max_seq_length,
|
272 |
+
)
|
273 |
+
elif up_block_type == "UpBlockSpatioTemporal":
|
274 |
+
# added for SVD (Stable Video Diffusion)
|
275 |
+
return UpBlockSpatioTemporal(
|
276 |
+
num_layers=num_layers,
|
277 |
+
in_channels=in_channels,
|
278 |
+
out_channels=out_channels,
|
279 |
+
prev_output_channel=prev_output_channel,
|
280 |
+
temb_channels=temb_channels,
|
281 |
+
resolution_idx=resolution_idx,
|
282 |
+
add_upsample=add_upsample,
|
283 |
+
)
|
284 |
+
elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
|
285 |
+
# added for SVD (Stable Video Diffusion)
|
286 |
+
if cross_attention_dim is None:
|
287 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
|
288 |
+
return CrossAttnUpBlockSpatioTemporal(
|
289 |
+
in_channels=in_channels,
|
290 |
+
out_channels=out_channels,
|
291 |
+
prev_output_channel=prev_output_channel,
|
292 |
+
temb_channels=temb_channels,
|
293 |
+
num_layers=num_layers,
|
294 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
295 |
+
add_upsample=add_upsample,
|
296 |
+
cross_attention_dim=cross_attention_dim,
|
297 |
+
num_attention_heads=num_attention_heads,
|
298 |
+
resolution_idx=resolution_idx,
|
299 |
+
)
|
300 |
+
|
301 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
302 |
+
|
303 |
+
|
304 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
305 |
+
def __init__(
|
306 |
+
self,
|
307 |
+
in_channels: int,
|
308 |
+
temb_channels: int,
|
309 |
+
dropout: float = 0.0,
|
310 |
+
num_layers: int = 1,
|
311 |
+
resnet_eps: float = 1e-6,
|
312 |
+
resnet_time_scale_shift: str = "default",
|
313 |
+
resnet_act_fn: str = "swish",
|
314 |
+
resnet_groups: int = 32,
|
315 |
+
resnet_pre_norm: bool = True,
|
316 |
+
num_attention_heads: int = 1,
|
317 |
+
output_scale_factor: float = 1.0,
|
318 |
+
cross_attention_dim: int = 1280,
|
319 |
+
dual_cross_attention: bool = False,
|
320 |
+
use_linear_projection: bool = True,
|
321 |
+
upcast_attention: bool = False,
|
322 |
+
):
|
323 |
+
super().__init__()
|
324 |
+
|
325 |
+
self.has_cross_attention = True
|
326 |
+
self.num_attention_heads = num_attention_heads
|
327 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
328 |
+
|
329 |
+
# there is always at least one resnet
|
330 |
+
resnets = [
|
331 |
+
ResnetBlock2D(
|
332 |
+
in_channels=in_channels,
|
333 |
+
out_channels=in_channels,
|
334 |
+
temb_channels=temb_channels,
|
335 |
+
eps=resnet_eps,
|
336 |
+
groups=resnet_groups,
|
337 |
+
dropout=dropout,
|
338 |
+
time_embedding_norm=resnet_time_scale_shift,
|
339 |
+
non_linearity=resnet_act_fn,
|
340 |
+
output_scale_factor=output_scale_factor,
|
341 |
+
pre_norm=resnet_pre_norm,
|
342 |
+
)
|
343 |
+
]
|
344 |
+
temp_convs = [
|
345 |
+
TemporalConvLayer(
|
346 |
+
in_channels,
|
347 |
+
in_channels,
|
348 |
+
dropout=0.1,
|
349 |
+
norm_num_groups=resnet_groups,
|
350 |
+
)
|
351 |
+
]
|
352 |
+
attentions = []
|
353 |
+
temp_attentions = []
|
354 |
+
|
355 |
+
for _ in range(num_layers):
|
356 |
+
attentions.append(
|
357 |
+
Transformer2DModel(
|
358 |
+
in_channels // num_attention_heads,
|
359 |
+
num_attention_heads,
|
360 |
+
in_channels=in_channels,
|
361 |
+
num_layers=1,
|
362 |
+
cross_attention_dim=cross_attention_dim,
|
363 |
+
norm_num_groups=resnet_groups,
|
364 |
+
use_linear_projection=use_linear_projection,
|
365 |
+
upcast_attention=upcast_attention,
|
366 |
+
)
|
367 |
+
)
|
368 |
+
temp_attentions.append(
|
369 |
+
TransformerTemporalModel(
|
370 |
+
in_channels // num_attention_heads,
|
371 |
+
num_attention_heads,
|
372 |
+
in_channels=in_channels,
|
373 |
+
num_layers=1,
|
374 |
+
cross_attention_dim=cross_attention_dim,
|
375 |
+
norm_num_groups=resnet_groups,
|
376 |
+
)
|
377 |
+
)
|
378 |
+
resnets.append(
|
379 |
+
ResnetBlock2D(
|
380 |
+
in_channels=in_channels,
|
381 |
+
out_channels=in_channels,
|
382 |
+
temb_channels=temb_channels,
|
383 |
+
eps=resnet_eps,
|
384 |
+
groups=resnet_groups,
|
385 |
+
dropout=dropout,
|
386 |
+
time_embedding_norm=resnet_time_scale_shift,
|
387 |
+
non_linearity=resnet_act_fn,
|
388 |
+
output_scale_factor=output_scale_factor,
|
389 |
+
pre_norm=resnet_pre_norm,
|
390 |
+
)
|
391 |
+
)
|
392 |
+
temp_convs.append(
|
393 |
+
TemporalConvLayer(
|
394 |
+
in_channels,
|
395 |
+
in_channels,
|
396 |
+
dropout=0.1,
|
397 |
+
norm_num_groups=resnet_groups,
|
398 |
+
)
|
399 |
+
)
|
400 |
+
|
401 |
+
self.resnets = nn.ModuleList(resnets)
|
402 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
403 |
+
self.attentions = nn.ModuleList(attentions)
|
404 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
405 |
+
|
406 |
+
def forward(
|
407 |
+
self,
|
408 |
+
hidden_states: torch.FloatTensor,
|
409 |
+
temb: Optional[torch.FloatTensor] = None,
|
410 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
411 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
412 |
+
num_frames: int = 1,
|
413 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
414 |
+
) -> torch.FloatTensor:
|
415 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
416 |
+
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
|
417 |
+
for attn, temp_attn, resnet, temp_conv in zip(
|
418 |
+
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
|
419 |
+
):
|
420 |
+
hidden_states = attn(
|
421 |
+
hidden_states,
|
422 |
+
encoder_hidden_states=encoder_hidden_states,
|
423 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
424 |
+
return_dict=False,
|
425 |
+
)[0]
|
426 |
+
hidden_states = temp_attn(
|
427 |
+
hidden_states,
|
428 |
+
num_frames=num_frames,
|
429 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
430 |
+
return_dict=False,
|
431 |
+
)[0]
|
432 |
+
hidden_states = resnet(hidden_states, temb)
|
433 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
434 |
+
|
435 |
+
return hidden_states
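# Illustrative sketch (not part of the diff): expected tensor layout for the mid block
# above. Frames are folded into the batch axis, i.e. inputs are
# (batch * num_frames, channels, height, width); all sizes here are assumptions and the
# snippet relies on the torch import at the top of this file.
mid_block = UNetMidBlock3DCrossAttn(in_channels=1280, temb_channels=1280, cross_attention_dim=768)
sample = torch.randn(2 * 16, 1280, 8, 8)      # 2 clips x 16 frames folded into the batch
temb = torch.randn(2 * 16, 1280)              # one time embedding per (clip, frame)
context = torch.randn(2 * 16, 77, 768)        # text context repeated per frame
sample = mid_block(sample, temb, encoder_hidden_states=context, num_frames=16)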
|
436 |
+
|
437 |
+
|
438 |
+
class CrossAttnDownBlock3D(nn.Module):
|
439 |
+
def __init__(
|
440 |
+
self,
|
441 |
+
in_channels: int,
|
442 |
+
out_channels: int,
|
443 |
+
temb_channels: int,
|
444 |
+
dropout: float = 0.0,
|
445 |
+
num_layers: int = 1,
|
446 |
+
resnet_eps: float = 1e-6,
|
447 |
+
resnet_time_scale_shift: str = "default",
|
448 |
+
resnet_act_fn: str = "swish",
|
449 |
+
resnet_groups: int = 32,
|
450 |
+
resnet_pre_norm: bool = True,
|
451 |
+
num_attention_heads: int = 1,
|
452 |
+
cross_attention_dim: int = 1280,
|
453 |
+
output_scale_factor: float = 1.0,
|
454 |
+
downsample_padding: int = 1,
|
455 |
+
add_downsample: bool = True,
|
456 |
+
dual_cross_attention: bool = False,
|
457 |
+
use_linear_projection: bool = False,
|
458 |
+
only_cross_attention: bool = False,
|
459 |
+
upcast_attention: bool = False,
|
460 |
+
):
|
461 |
+
super().__init__()
|
462 |
+
resnets = []
|
463 |
+
attentions = []
|
464 |
+
temp_attentions = []
|
465 |
+
temp_convs = []
|
466 |
+
|
467 |
+
self.has_cross_attention = True
|
468 |
+
self.num_attention_heads = num_attention_heads
|
469 |
+
|
470 |
+
for i in range(num_layers):
|
471 |
+
in_channels = in_channels if i == 0 else out_channels
|
472 |
+
resnets.append(
|
473 |
+
ResnetBlock2D(
|
474 |
+
in_channels=in_channels,
|
475 |
+
out_channels=out_channels,
|
476 |
+
temb_channels=temb_channels,
|
477 |
+
eps=resnet_eps,
|
478 |
+
groups=resnet_groups,
|
479 |
+
dropout=dropout,
|
480 |
+
time_embedding_norm=resnet_time_scale_shift,
|
481 |
+
non_linearity=resnet_act_fn,
|
482 |
+
output_scale_factor=output_scale_factor,
|
483 |
+
pre_norm=resnet_pre_norm,
|
484 |
+
)
|
485 |
+
)
|
486 |
+
temp_convs.append(
|
487 |
+
TemporalConvLayer(
|
488 |
+
out_channels,
|
489 |
+
out_channels,
|
490 |
+
dropout=0.1,
|
491 |
+
norm_num_groups=resnet_groups,
|
492 |
+
)
|
493 |
+
)
|
494 |
+
attentions.append(
|
495 |
+
Transformer2DModel(
|
496 |
+
out_channels // num_attention_heads,
|
497 |
+
num_attention_heads,
|
498 |
+
in_channels=out_channels,
|
499 |
+
num_layers=1,
|
500 |
+
cross_attention_dim=cross_attention_dim,
|
501 |
+
norm_num_groups=resnet_groups,
|
502 |
+
use_linear_projection=use_linear_projection,
|
503 |
+
only_cross_attention=only_cross_attention,
|
504 |
+
upcast_attention=upcast_attention,
|
505 |
+
)
|
506 |
+
)
|
507 |
+
temp_attentions.append(
|
508 |
+
TransformerTemporalModel(
|
509 |
+
out_channels // num_attention_heads,
|
510 |
+
num_attention_heads,
|
511 |
+
in_channels=out_channels,
|
512 |
+
num_layers=1,
|
513 |
+
cross_attention_dim=cross_attention_dim,
|
514 |
+
norm_num_groups=resnet_groups,
|
515 |
+
)
|
516 |
+
)
|
517 |
+
self.resnets = nn.ModuleList(resnets)
|
518 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
519 |
+
self.attentions = nn.ModuleList(attentions)
|
520 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
521 |
+
|
522 |
+
if add_downsample:
|
523 |
+
self.downsamplers = nn.ModuleList(
|
524 |
+
[
|
525 |
+
Downsample2D(
|
526 |
+
out_channels,
|
527 |
+
use_conv=True,
|
528 |
+
out_channels=out_channels,
|
529 |
+
padding=downsample_padding,
|
530 |
+
name="op",
|
531 |
+
)
|
532 |
+
]
|
533 |
+
)
|
534 |
+
else:
|
535 |
+
self.downsamplers = None
|
536 |
+
|
537 |
+
self.gradient_checkpointing = False
|
538 |
+
|
539 |
+
def forward(
|
540 |
+
self,
|
541 |
+
hidden_states: torch.FloatTensor,
|
542 |
+
temb: Optional[torch.FloatTensor] = None,
|
543 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
544 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
545 |
+
num_frames: int = 1,
|
546 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
547 |
+
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
548 |
+
# TODO(Patrick, William) - attention mask is not used
|
549 |
+
output_states = ()
|
550 |
+
|
551 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
552 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
553 |
+
):
|
554 |
+
hidden_states = resnet(hidden_states, temb)
|
555 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
556 |
+
hidden_states = attn(
|
557 |
+
hidden_states,
|
558 |
+
encoder_hidden_states=encoder_hidden_states,
|
559 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
560 |
+
return_dict=False,
|
561 |
+
)[0]
|
562 |
+
hidden_states = temp_attn(
|
563 |
+
hidden_states,
|
564 |
+
num_frames=num_frames,
|
565 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
566 |
+
return_dict=False,
|
567 |
+
)[0]
|
568 |
+
|
569 |
+
output_states += (hidden_states,)
|
570 |
+
|
571 |
+
if self.downsamplers is not None:
|
572 |
+
for downsampler in self.downsamplers:
|
573 |
+
hidden_states = downsampler(hidden_states)
|
574 |
+
|
575 |
+
output_states += (hidden_states,)
|
576 |
+
|
577 |
+
return hidden_states, output_states
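# Illustrative sketch (not part of the diff): the down blocks return both the final
# feature map and a tuple of intermediate states that the UNet keeps as skip
# connections for the mirrored up block. Sizes below are assumptions.
down_block = CrossAttnDownBlock3D(
    in_channels=320, out_channels=320, temb_channels=1280,
    num_layers=2, cross_attention_dim=768, num_attention_heads=8,
)
sample = torch.randn(16, 320, 64, 64)          # 1 clip x 16 frames folded into the batch
sample, skip_states = down_block(
    sample, temb=torch.randn(16, 1280),
    encoder_hidden_states=torch.randn(16, 77, 768), num_frames=16,
)
# len(skip_states) == num_layers + 1: one entry per resnet layer plus one after the downsampler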
|
578 |
+
|
579 |
+
|
580 |
+
class DownBlock3D(nn.Module):
|
581 |
+
def __init__(
|
582 |
+
self,
|
583 |
+
in_channels: int,
|
584 |
+
out_channels: int,
|
585 |
+
temb_channels: int,
|
586 |
+
dropout: float = 0.0,
|
587 |
+
num_layers: int = 1,
|
588 |
+
resnet_eps: float = 1e-6,
|
589 |
+
resnet_time_scale_shift: str = "default",
|
590 |
+
resnet_act_fn: str = "swish",
|
591 |
+
resnet_groups: int = 32,
|
592 |
+
resnet_pre_norm: bool = True,
|
593 |
+
output_scale_factor: float = 1.0,
|
594 |
+
add_downsample: bool = True,
|
595 |
+
downsample_padding: int = 1,
|
596 |
+
):
|
597 |
+
super().__init__()
|
598 |
+
resnets = []
|
599 |
+
temp_convs = []
|
600 |
+
|
601 |
+
for i in range(num_layers):
|
602 |
+
in_channels = in_channels if i == 0 else out_channels
|
603 |
+
resnets.append(
|
604 |
+
ResnetBlock2D(
|
605 |
+
in_channels=in_channels,
|
606 |
+
out_channels=out_channels,
|
607 |
+
temb_channels=temb_channels,
|
608 |
+
eps=resnet_eps,
|
609 |
+
groups=resnet_groups,
|
610 |
+
dropout=dropout,
|
611 |
+
time_embedding_norm=resnet_time_scale_shift,
|
612 |
+
non_linearity=resnet_act_fn,
|
613 |
+
output_scale_factor=output_scale_factor,
|
614 |
+
pre_norm=resnet_pre_norm,
|
615 |
+
)
|
616 |
+
)
|
617 |
+
temp_convs.append(
|
618 |
+
TemporalConvLayer(
|
619 |
+
out_channels,
|
620 |
+
out_channels,
|
621 |
+
dropout=0.1,
|
622 |
+
norm_num_groups=resnet_groups,
|
623 |
+
)
|
624 |
+
)
|
625 |
+
|
626 |
+
self.resnets = nn.ModuleList(resnets)
|
627 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
628 |
+
|
629 |
+
if add_downsample:
|
630 |
+
self.downsamplers = nn.ModuleList(
|
631 |
+
[
|
632 |
+
Downsample2D(
|
633 |
+
out_channels,
|
634 |
+
use_conv=True,
|
635 |
+
out_channels=out_channels,
|
636 |
+
padding=downsample_padding,
|
637 |
+
name="op",
|
638 |
+
)
|
639 |
+
]
|
640 |
+
)
|
641 |
+
else:
|
642 |
+
self.downsamplers = None
|
643 |
+
|
644 |
+
self.gradient_checkpointing = False
|
645 |
+
|
646 |
+
def forward(
|
647 |
+
self,
|
648 |
+
hidden_states: torch.FloatTensor,
|
649 |
+
temb: Optional[torch.FloatTensor] = None,
|
650 |
+
num_frames: int = 1,
|
651 |
+
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
652 |
+
output_states = ()
|
653 |
+
|
654 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
655 |
+
hidden_states = resnet(hidden_states, temb)
|
656 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
657 |
+
|
658 |
+
output_states += (hidden_states,)
|
659 |
+
|
660 |
+
if self.downsamplers is not None:
|
661 |
+
for downsampler in self.downsamplers:
|
662 |
+
hidden_states = downsampler(hidden_states)
|
663 |
+
|
664 |
+
output_states += (hidden_states,)
|
665 |
+
|
666 |
+
return hidden_states, output_states
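# Illustrative sketch (not part of the diff): DownBlock3D is the attention-free variant;
# only the first resnet changes the channel count (in_channels -> out_channels) and the
# optional trailing downsampler halves the spatial size. Values are assumptions.
plain_down = DownBlock3D(in_channels=320, out_channels=640, temb_channels=1280, num_layers=2)
out, skip_states = plain_down(
    torch.randn(8, 320, 32, 32), temb=torch.randn(8, 1280), num_frames=8,
)
# out.shape == (8, 640, 16, 16) after the stride-2 downsampler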
|
667 |
+
|
668 |
+
|
669 |
+
class CrossAttnUpBlock3D(nn.Module):
|
670 |
+
def __init__(
|
671 |
+
self,
|
672 |
+
in_channels: int,
|
673 |
+
out_channels: int,
|
674 |
+
prev_output_channel: int,
|
675 |
+
temb_channels: int,
|
676 |
+
dropout: float = 0.0,
|
677 |
+
num_layers: int = 1,
|
678 |
+
resnet_eps: float = 1e-6,
|
679 |
+
resnet_time_scale_shift: str = "default",
|
680 |
+
resnet_act_fn: str = "swish",
|
681 |
+
resnet_groups: int = 32,
|
682 |
+
resnet_pre_norm: bool = True,
|
683 |
+
num_attention_heads: int = 1,
|
684 |
+
cross_attention_dim: int = 1280,
|
685 |
+
output_scale_factor: float = 1.0,
|
686 |
+
add_upsample: bool = True,
|
687 |
+
dual_cross_attention: bool = False,
|
688 |
+
use_linear_projection: bool = False,
|
689 |
+
only_cross_attention: bool = False,
|
690 |
+
upcast_attention: bool = False,
|
691 |
+
resolution_idx: Optional[int] = None,
|
692 |
+
):
|
693 |
+
super().__init__()
|
694 |
+
resnets = []
|
695 |
+
temp_convs = []
|
696 |
+
attentions = []
|
697 |
+
temp_attentions = []
|
698 |
+
|
699 |
+
self.has_cross_attention = True
|
700 |
+
self.num_attention_heads = num_attention_heads
|
701 |
+
|
702 |
+
for i in range(num_layers):
|
703 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
704 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
705 |
+
|
706 |
+
resnets.append(
|
707 |
+
ResnetBlock2D(
|
708 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
709 |
+
out_channels=out_channels,
|
710 |
+
temb_channels=temb_channels,
|
711 |
+
eps=resnet_eps,
|
712 |
+
groups=resnet_groups,
|
713 |
+
dropout=dropout,
|
714 |
+
time_embedding_norm=resnet_time_scale_shift,
|
715 |
+
non_linearity=resnet_act_fn,
|
716 |
+
output_scale_factor=output_scale_factor,
|
717 |
+
pre_norm=resnet_pre_norm,
|
718 |
+
)
|
719 |
+
)
|
720 |
+
temp_convs.append(
|
721 |
+
TemporalConvLayer(
|
722 |
+
out_channels,
|
723 |
+
out_channels,
|
724 |
+
dropout=0.1,
|
725 |
+
norm_num_groups=resnet_groups,
|
726 |
+
)
|
727 |
+
)
|
728 |
+
attentions.append(
|
729 |
+
Transformer2DModel(
|
730 |
+
out_channels // num_attention_heads,
|
731 |
+
num_attention_heads,
|
732 |
+
in_channels=out_channels,
|
733 |
+
num_layers=1,
|
734 |
+
cross_attention_dim=cross_attention_dim,
|
735 |
+
norm_num_groups=resnet_groups,
|
736 |
+
use_linear_projection=use_linear_projection,
|
737 |
+
only_cross_attention=only_cross_attention,
|
738 |
+
upcast_attention=upcast_attention,
|
739 |
+
)
|
740 |
+
)
|
741 |
+
temp_attentions.append(
|
742 |
+
TransformerTemporalModel(
|
743 |
+
out_channels // num_attention_heads,
|
744 |
+
num_attention_heads,
|
745 |
+
in_channels=out_channels,
|
746 |
+
num_layers=1,
|
747 |
+
cross_attention_dim=cross_attention_dim,
|
748 |
+
norm_num_groups=resnet_groups,
|
749 |
+
)
|
750 |
+
)
|
751 |
+
self.resnets = nn.ModuleList(resnets)
|
752 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
753 |
+
self.attentions = nn.ModuleList(attentions)
|
754 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
755 |
+
|
756 |
+
if add_upsample:
|
757 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
758 |
+
else:
|
759 |
+
self.upsamplers = None
|
760 |
+
|
761 |
+
self.gradient_checkpointing = False
|
762 |
+
self.resolution_idx = resolution_idx
|
763 |
+
|
764 |
+
def forward(
|
765 |
+
self,
|
766 |
+
hidden_states: torch.FloatTensor,
|
767 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
768 |
+
temb: Optional[torch.FloatTensor] = None,
|
769 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
770 |
+
upsample_size: Optional[int] = None,
|
771 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
772 |
+
num_frames: int = 1,
|
773 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
774 |
+
) -> torch.FloatTensor:
|
775 |
+
is_freeu_enabled = (
|
776 |
+
getattr(self, "s1", None)
|
777 |
+
and getattr(self, "s2", None)
|
778 |
+
and getattr(self, "b1", None)
|
779 |
+
and getattr(self, "b2", None)
|
780 |
+
)
|
781 |
+
|
782 |
+
# TODO(Patrick, William) - attention mask is not used
|
783 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
784 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
785 |
+
):
|
786 |
+
# pop res hidden states
|
787 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
788 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
789 |
+
|
790 |
+
# FreeU: Only operate on the first two stages
|
791 |
+
if is_freeu_enabled:
|
792 |
+
hidden_states, res_hidden_states = apply_freeu(
|
793 |
+
self.resolution_idx,
|
794 |
+
hidden_states,
|
795 |
+
res_hidden_states,
|
796 |
+
s1=self.s1,
|
797 |
+
s2=self.s2,
|
798 |
+
b1=self.b1,
|
799 |
+
b2=self.b2,
|
800 |
+
)
|
801 |
+
|
802 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
803 |
+
|
804 |
+
hidden_states = resnet(hidden_states, temb)
|
805 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
806 |
+
hidden_states = attn(
|
807 |
+
hidden_states,
|
808 |
+
encoder_hidden_states=encoder_hidden_states,
|
809 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
810 |
+
return_dict=False,
|
811 |
+
)[0]
|
812 |
+
hidden_states = temp_attn(
|
813 |
+
hidden_states,
|
814 |
+
num_frames=num_frames,
|
815 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
816 |
+
return_dict=False,
|
817 |
+
)[0]
|
818 |
+
|
819 |
+
if self.upsamplers is not None:
|
820 |
+
for upsampler in self.upsamplers:
|
821 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
822 |
+
|
823 |
+
return hidden_states
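# Illustrative sketch (not part of the diff): each layer of an up block pops one skip
# tensor from the END of res_hidden_states_tuple and concatenates it on the channel
# axis before its resnet, so channel counts must mirror the down path. Sizes are assumptions.
up_block = CrossAttnUpBlock3D(
    in_channels=320, out_channels=640, prev_output_channel=640,
    temb_channels=1280, num_layers=3, cross_attention_dim=768, num_attention_heads=8,
)
sample = torch.randn(8, 640, 16, 16)
skips = (
    torch.randn(8, 320, 16, 16),   # consumed last (i == num_layers - 1 -> in_channels)
    torch.randn(8, 640, 16, 16),
    torch.randn(8, 640, 16, 16),   # consumed first
)
sample = up_block(
    sample, skips, temb=torch.randn(8, 1280),
    encoder_hidden_states=torch.randn(8, 77, 768), num_frames=8,
)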
|
824 |
+
|
825 |
+
|
826 |
+
class UpBlock3D(nn.Module):
|
827 |
+
def __init__(
|
828 |
+
self,
|
829 |
+
in_channels: int,
|
830 |
+
prev_output_channel: int,
|
831 |
+
out_channels: int,
|
832 |
+
temb_channels: int,
|
833 |
+
dropout: float = 0.0,
|
834 |
+
num_layers: int = 1,
|
835 |
+
resnet_eps: float = 1e-6,
|
836 |
+
resnet_time_scale_shift: str = "default",
|
837 |
+
resnet_act_fn: str = "swish",
|
838 |
+
resnet_groups: int = 32,
|
839 |
+
resnet_pre_norm: bool = True,
|
840 |
+
output_scale_factor: float = 1.0,
|
841 |
+
add_upsample: bool = True,
|
842 |
+
resolution_idx: Optional[int] = None,
|
843 |
+
):
|
844 |
+
super().__init__()
|
845 |
+
resnets = []
|
846 |
+
temp_convs = []
|
847 |
+
|
848 |
+
for i in range(num_layers):
|
849 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
850 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
851 |
+
|
852 |
+
resnets.append(
|
853 |
+
ResnetBlock2D(
|
854 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
855 |
+
out_channels=out_channels,
|
856 |
+
temb_channels=temb_channels,
|
857 |
+
eps=resnet_eps,
|
858 |
+
groups=resnet_groups,
|
859 |
+
dropout=dropout,
|
860 |
+
time_embedding_norm=resnet_time_scale_shift,
|
861 |
+
non_linearity=resnet_act_fn,
|
862 |
+
output_scale_factor=output_scale_factor,
|
863 |
+
pre_norm=resnet_pre_norm,
|
864 |
+
)
|
865 |
+
)
|
866 |
+
temp_convs.append(
|
867 |
+
TemporalConvLayer(
|
868 |
+
out_channels,
|
869 |
+
out_channels,
|
870 |
+
dropout=0.1,
|
871 |
+
norm_num_groups=resnet_groups,
|
872 |
+
)
|
873 |
+
)
|
874 |
+
|
875 |
+
self.resnets = nn.ModuleList(resnets)
|
876 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
877 |
+
|
878 |
+
if add_upsample:
|
879 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
880 |
+
else:
|
881 |
+
self.upsamplers = None
|
882 |
+
|
883 |
+
self.gradient_checkpointing = False
|
884 |
+
self.resolution_idx = resolution_idx
|
885 |
+
|
886 |
+
def forward(
|
887 |
+
self,
|
888 |
+
hidden_states: torch.FloatTensor,
|
889 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
890 |
+
temb: Optional[torch.FloatTensor] = None,
|
891 |
+
upsample_size: Optional[int] = None,
|
892 |
+
num_frames: int = 1,
|
893 |
+
) -> torch.FloatTensor:
|
894 |
+
is_freeu_enabled = (
|
895 |
+
getattr(self, "s1", None)
|
896 |
+
and getattr(self, "s2", None)
|
897 |
+
and getattr(self, "b1", None)
|
898 |
+
and getattr(self, "b2", None)
|
899 |
+
)
|
900 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
901 |
+
# pop res hidden states
|
902 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
903 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
904 |
+
|
905 |
+
# FreeU: Only operate on the first two stages
|
906 |
+
if is_freeu_enabled:
|
907 |
+
hidden_states, res_hidden_states = apply_freeu(
|
908 |
+
self.resolution_idx,
|
909 |
+
hidden_states,
|
910 |
+
res_hidden_states,
|
911 |
+
s1=self.s1,
|
912 |
+
s2=self.s2,
|
913 |
+
b1=self.b1,
|
914 |
+
b2=self.b2,
|
915 |
+
)
|
916 |
+
|
917 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
918 |
+
|
919 |
+
hidden_states = resnet(hidden_states, temb)
|
920 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
921 |
+
|
922 |
+
if self.upsamplers is not None:
|
923 |
+
for upsampler in self.upsamplers:
|
924 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
925 |
+
|
926 |
+
return hidden_states
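# Illustrative sketch (not part of the diff): the FreeU branch above only activates when
# s1/s2/b1/b2 have been attached to the block (diffusers does this via enable_freeu on the
# pipeline/UNet). Setting them by hand, as below, is purely for illustration; the values
# are the commonly quoted SD1.x defaults, not something prescribed by this repo.
freeu_block = UpBlock3D(
    in_channels=320, prev_output_channel=640, out_channels=640,
    temb_channels=1280, num_layers=1, resolution_idx=0,
)
freeu_block.s1, freeu_block.s2, freeu_block.b1, freeu_block.b2 = 0.9, 0.2, 1.2, 1.4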
|
927 |
+
|
928 |
+
|
929 |
+
class DownBlockMotion(nn.Module):
|
930 |
+
def __init__(
|
931 |
+
self,
|
932 |
+
in_channels: int,
|
933 |
+
out_channels: int,
|
934 |
+
temb_channels: int,
|
935 |
+
dropout: float = 0.0,
|
936 |
+
num_layers: int = 1,
|
937 |
+
resnet_eps: float = 1e-6,
|
938 |
+
resnet_time_scale_shift: str = "default",
|
939 |
+
resnet_act_fn: str = "swish",
|
940 |
+
resnet_groups: int = 32,
|
941 |
+
resnet_pre_norm: bool = True,
|
942 |
+
output_scale_factor: float = 1.0,
|
943 |
+
add_downsample: bool = True,
|
944 |
+
downsample_padding: int = 1,
|
945 |
+
temporal_num_attention_heads: int = 1,
|
946 |
+
temporal_cross_attention_dim: Optional[int] = None,
|
947 |
+
temporal_max_seq_length: int = 32,
|
948 |
+
):
|
949 |
+
super().__init__()
|
950 |
+
resnets = []
|
951 |
+
motion_modules = []
|
952 |
+
|
953 |
+
for i in range(num_layers):
|
954 |
+
in_channels = in_channels if i == 0 else out_channels
|
955 |
+
resnets.append(
|
956 |
+
ResnetBlock2D(
|
957 |
+
in_channels=in_channels,
|
958 |
+
out_channels=out_channels,
|
959 |
+
temb_channels=temb_channels,
|
960 |
+
eps=resnet_eps,
|
961 |
+
groups=resnet_groups,
|
962 |
+
dropout=dropout,
|
963 |
+
time_embedding_norm=resnet_time_scale_shift,
|
964 |
+
non_linearity=resnet_act_fn,
|
965 |
+
output_scale_factor=output_scale_factor,
|
966 |
+
pre_norm=resnet_pre_norm,
|
967 |
+
)
|
968 |
+
)
|
969 |
+
motion_modules.append(
|
970 |
+
TransformerTemporalModel(
|
971 |
+
num_attention_heads=temporal_num_attention_heads,
|
972 |
+
in_channels=out_channels,
|
973 |
+
norm_num_groups=resnet_groups,
|
974 |
+
cross_attention_dim=temporal_cross_attention_dim,
|
975 |
+
attention_bias=False,
|
976 |
+
activation_fn="geglu",
|
977 |
+
positional_embeddings="sinusoidal",
|
978 |
+
num_positional_embeddings=temporal_max_seq_length,
|
979 |
+
attention_head_dim=out_channels // temporal_num_attention_heads,
|
980 |
+
)
|
981 |
+
)
|
982 |
+
|
983 |
+
self.resnets = nn.ModuleList(resnets)
|
984 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
985 |
+
|
986 |
+
if add_downsample:
|
987 |
+
self.downsamplers = nn.ModuleList(
|
988 |
+
[
|
989 |
+
Downsample2D(
|
990 |
+
out_channels,
|
991 |
+
use_conv=True,
|
992 |
+
out_channels=out_channels,
|
993 |
+
padding=downsample_padding,
|
994 |
+
name="op",
|
995 |
+
)
|
996 |
+
]
|
997 |
+
)
|
998 |
+
else:
|
999 |
+
self.downsamplers = None
|
1000 |
+
|
1001 |
+
self.gradient_checkpointing = False
|
1002 |
+
|
1003 |
+
def forward(
|
1004 |
+
self,
|
1005 |
+
hidden_states: torch.FloatTensor,
|
1006 |
+
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
1007 |
+
temb: Optional[torch.FloatTensor] = None,
|
1008 |
+
scale: float = 1.0,
|
1009 |
+
num_frames: int = 1,
|
1010 |
+
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
1011 |
+
output_states = ()
|
1012 |
+
|
1013 |
+
blocks = zip(self.resnets, self.motion_modules)
|
1014 |
+
for resnet, motion_module in blocks:
|
1015 |
+
if self.training and self.gradient_checkpointing:
|
1016 |
+
|
1017 |
+
def create_custom_forward(module, return_dict=None):
|
1018 |
+
def custom_forward(*inputs):
|
1019 |
+
if return_dict is not None:
|
1020 |
+
return module(*inputs, return_dict=return_dict)
|
1021 |
+
else:
|
1022 |
+
return module(*inputs)
|
1023 |
+
|
1024 |
+
return custom_forward
|
1025 |
+
|
1026 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1027 |
+
|
1028 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1029 |
+
create_custom_forward(resnet),
|
1030 |
+
hidden_states,
|
1031 |
+
temb,
|
1032 |
+
**ckpt_kwargs,
|
1033 |
+
)
|
1034 |
+
|
1035 |
+
if down_block_add_samples is not None:
|
1036 |
+
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
1037 |
+
|
1038 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1039 |
+
create_custom_forward(motion_module),
|
1040 |
+
hidden_states.requires_grad_(),
|
1041 |
+
temb,
|
1042 |
+
num_frames,
|
1043 |
+
**ckpt_kwargs,
|
1044 |
+
)
|
1045 |
+
|
1046 |
+
else:
|
1047 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1048 |
+
if down_block_add_samples is not None:
|
1049 |
+
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
1050 |
+
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
1051 |
+
|
1052 |
+
output_states = output_states + (hidden_states,)
|
1053 |
+
|
1054 |
+
if self.downsamplers is not None:
|
1055 |
+
for downsampler in self.downsamplers:
|
1056 |
+
hidden_states = downsampler(hidden_states, scale=scale)
|
1057 |
+
|
1058 |
+
if down_block_add_samples is not None:
|
1059 |
+
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
1060 |
+
|
1061 |
+
output_states = output_states + (hidden_states,)
|
1062 |
+
|
1063 |
+
return hidden_states, output_states
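# Illustrative sketch (not part of the diff): down_block_add_samples is consumed with
# pop(0), one tensor per resnet and one per downsampler, so the caller must pass a
# mutable list ordered to match. In this repo those residuals come from the auxiliary
# BrushNet-style branch; the zero tensors below are placeholders with assumed shapes.
motion_down = DownBlockMotion(in_channels=320, out_channels=320, temb_channels=1280, num_layers=2)
residuals = [
    torch.zeros(16, 320, 64, 64),   # added after resnet 1 (before its motion module)
    torch.zeros(16, 320, 64, 64),   # added after resnet 2 (before its motion module)
    torch.zeros(16, 320, 32, 32),   # added after the downsampler
]
sample, skip_states = motion_down(
    torch.randn(16, 320, 64, 64),
    down_block_add_samples=residuals,
    temb=torch.randn(16, 1280),
    num_frames=16,
)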
|
1064 |
+
|
1065 |
+
|
1066 |
+
class CrossAttnDownBlockMotion(nn.Module):
|
1067 |
+
def __init__(
|
1068 |
+
self,
|
1069 |
+
in_channels: int,
|
1070 |
+
out_channels: int,
|
1071 |
+
temb_channels: int,
|
1072 |
+
dropout: float = 0.0,
|
1073 |
+
num_layers: int = 1,
|
1074 |
+
transformer_layers_per_block: int = 1,
|
1075 |
+
resnet_eps: float = 1e-6,
|
1076 |
+
resnet_time_scale_shift: str = "default",
|
1077 |
+
resnet_act_fn: str = "swish",
|
1078 |
+
resnet_groups: int = 32,
|
1079 |
+
resnet_pre_norm: bool = True,
|
1080 |
+
num_attention_heads: int = 1,
|
1081 |
+
cross_attention_dim: int = 1280,
|
1082 |
+
output_scale_factor: float = 1.0,
|
1083 |
+
downsample_padding: int = 1,
|
1084 |
+
add_downsample: bool = True,
|
1085 |
+
dual_cross_attention: bool = False,
|
1086 |
+
use_linear_projection: bool = False,
|
1087 |
+
only_cross_attention: bool = False,
|
1088 |
+
upcast_attention: bool = False,
|
1089 |
+
attention_type: str = "default",
|
1090 |
+
temporal_cross_attention_dim: Optional[int] = None,
|
1091 |
+
temporal_num_attention_heads: int = 8,
|
1092 |
+
temporal_max_seq_length: int = 32,
|
1093 |
+
):
|
1094 |
+
super().__init__()
|
1095 |
+
resnets = []
|
1096 |
+
attentions = []
|
1097 |
+
motion_modules = []
|
1098 |
+
|
1099 |
+
self.has_cross_attention = True
|
1100 |
+
self.num_attention_heads = num_attention_heads
|
1101 |
+
|
1102 |
+
for i in range(num_layers):
|
1103 |
+
in_channels = in_channels if i == 0 else out_channels
|
1104 |
+
resnets.append(
|
1105 |
+
ResnetBlock2D(
|
1106 |
+
in_channels=in_channels,
|
1107 |
+
out_channels=out_channels,
|
1108 |
+
temb_channels=temb_channels,
|
1109 |
+
eps=resnet_eps,
|
1110 |
+
groups=resnet_groups,
|
1111 |
+
dropout=dropout,
|
1112 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1113 |
+
non_linearity=resnet_act_fn,
|
1114 |
+
output_scale_factor=output_scale_factor,
|
1115 |
+
pre_norm=resnet_pre_norm,
|
1116 |
+
)
|
1117 |
+
)
|
1118 |
+
|
1119 |
+
if not dual_cross_attention:
|
1120 |
+
attentions.append(
|
1121 |
+
Transformer2DModel(
|
1122 |
+
num_attention_heads,
|
1123 |
+
out_channels // num_attention_heads,
|
1124 |
+
in_channels=out_channels,
|
1125 |
+
num_layers=transformer_layers_per_block,
|
1126 |
+
cross_attention_dim=cross_attention_dim,
|
1127 |
+
norm_num_groups=resnet_groups,
|
1128 |
+
use_linear_projection=use_linear_projection,
|
1129 |
+
only_cross_attention=only_cross_attention,
|
1130 |
+
upcast_attention=upcast_attention,
|
1131 |
+
attention_type=attention_type,
|
1132 |
+
)
|
1133 |
+
)
|
1134 |
+
else:
|
1135 |
+
attentions.append(
|
1136 |
+
DualTransformer2DModel(
|
1137 |
+
num_attention_heads,
|
1138 |
+
out_channels // num_attention_heads,
|
1139 |
+
in_channels=out_channels,
|
1140 |
+
num_layers=1,
|
1141 |
+
cross_attention_dim=cross_attention_dim,
|
1142 |
+
norm_num_groups=resnet_groups,
|
1143 |
+
)
|
1144 |
+
)
|
1145 |
+
|
1146 |
+
motion_modules.append(
|
1147 |
+
TransformerTemporalModel(
|
1148 |
+
num_attention_heads=temporal_num_attention_heads,
|
1149 |
+
in_channels=out_channels,
|
1150 |
+
norm_num_groups=resnet_groups,
|
1151 |
+
cross_attention_dim=temporal_cross_attention_dim,
|
1152 |
+
attention_bias=False,
|
1153 |
+
activation_fn="geglu",
|
1154 |
+
positional_embeddings="sinusoidal",
|
1155 |
+
num_positional_embeddings=temporal_max_seq_length,
|
1156 |
+
attention_head_dim=out_channels // temporal_num_attention_heads,
|
1157 |
+
)
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
self.attentions = nn.ModuleList(attentions)
|
1161 |
+
self.resnets = nn.ModuleList(resnets)
|
1162 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
1163 |
+
|
1164 |
+
if add_downsample:
|
1165 |
+
self.downsamplers = nn.ModuleList(
|
1166 |
+
[
|
1167 |
+
Downsample2D(
|
1168 |
+
out_channels,
|
1169 |
+
use_conv=True,
|
1170 |
+
out_channels=out_channels,
|
1171 |
+
padding=downsample_padding,
|
1172 |
+
name="op",
|
1173 |
+
)
|
1174 |
+
]
|
1175 |
+
)
|
1176 |
+
else:
|
1177 |
+
self.downsamplers = None
|
1178 |
+
|
1179 |
+
self.gradient_checkpointing = False
|
1180 |
+
|
1181 |
+
def forward(
|
1182 |
+
self,
|
1183 |
+
hidden_states: torch.FloatTensor,
|
1184 |
+
temb: Optional[torch.FloatTensor] = None,
|
1185 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1186 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1187 |
+
num_frames: int = 1,
|
1188 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1189 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1190 |
+
additional_residuals: Optional[torch.FloatTensor] = None,
|
1191 |
+
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
1192 |
+
):
|
1193 |
+
output_states = ()
|
1194 |
+
|
1195 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
1196 |
+
|
1197 |
+
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
|
1198 |
+
for i, (resnet, attn, motion_module) in enumerate(blocks):
|
1199 |
+
if self.training and self.gradient_checkpointing:
|
1200 |
+
|
1201 |
+
def create_custom_forward(module, return_dict=None):
|
1202 |
+
def custom_forward(*inputs):
|
1203 |
+
if return_dict is not None:
|
1204 |
+
return module(*inputs, return_dict=return_dict)
|
1205 |
+
else:
|
1206 |
+
return module(*inputs)
|
1207 |
+
|
1208 |
+
return custom_forward
|
1209 |
+
|
1210 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1211 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1212 |
+
create_custom_forward(resnet),
|
1213 |
+
hidden_states,
|
1214 |
+
temb,
|
1215 |
+
**ckpt_kwargs,
|
1216 |
+
)
|
1217 |
+
hidden_states = attn(
|
1218 |
+
hidden_states,
|
1219 |
+
encoder_hidden_states=encoder_hidden_states,
|
1220 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1221 |
+
attention_mask=attention_mask,
|
1222 |
+
encoder_attention_mask=encoder_attention_mask,
|
1223 |
+
return_dict=False,
|
1224 |
+
)[0]
|
1225 |
+
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
1226 |
+
if i == len(blocks) - 1 and additional_residuals is not None:
|
1227 |
+
hidden_states = hidden_states + additional_residuals
|
1228 |
+
if down_block_add_samples is not None:
|
1229 |
+
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
1230 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1231 |
+
create_custom_forward(motion_module),
|
1232 |
+
hidden_states.requires_grad_(),
|
1233 |
+
temb,
|
1234 |
+
num_frames,
|
1235 |
+
**ckpt_kwargs,
|
1236 |
+
)
|
1237 |
+
else:
|
1238 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1239 |
+
hidden_states = attn(
|
1240 |
+
hidden_states,
|
1241 |
+
encoder_hidden_states=encoder_hidden_states,
|
1242 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1243 |
+
attention_mask=attention_mask,
|
1244 |
+
encoder_attention_mask=encoder_attention_mask,
|
1245 |
+
return_dict=False,
|
1246 |
+
)[0]
|
1247 |
+
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
1248 |
+
if i == len(blocks) - 1 and additional_residuals is not None:
|
1249 |
+
hidden_states = hidden_states + additional_residuals
|
1250 |
+
if down_block_add_samples is not None:
|
1251 |
+
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
1252 |
+
hidden_states = motion_module(
|
1253 |
+
hidden_states,
|
1254 |
+
num_frames=num_frames,
|
1255 |
+
)
|
1256 |
+
|
1257 |
+
# # apply additional residuals to the output of the last pair of resnet and attention blocks
|
1258 |
+
# if i == len(blocks) - 1 and additional_residuals is not None:
|
1259 |
+
# hidden_states = hidden_states + additional_residuals
|
1260 |
+
|
1261 |
+
output_states = output_states + (hidden_states,)
|
1262 |
+
|
1263 |
+
if self.downsamplers is not None:
|
1264 |
+
for downsampler in self.downsamplers:
|
1265 |
+
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
1266 |
+
|
1267 |
+
if down_block_add_samples is not None:
|
1268 |
+
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
1269 |
+
|
1270 |
+
output_states = output_states + (hidden_states,)
|
1271 |
+
|
1272 |
+
return hidden_states, output_states
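# Illustrative sketch (not part of the diff): gradient checkpointing on these blocks is a
# plain boolean attribute that the parent UNet normally flips through its
# enable_gradient_checkpointing() helper; toggling it by hand is shown only for
# illustration, and it takes effect only while the module is in training mode.
ca_down = CrossAttnDownBlockMotion(
    in_channels=320, out_channels=320, temb_channels=1280,
    cross_attention_dim=768, num_attention_heads=8,
)
ca_down.gradient_checkpointing = True
ca_down.train()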
|
1273 |
+
|
1274 |
+
|
1275 |
+
class CrossAttnUpBlockMotion(nn.Module):
|
1276 |
+
def __init__(
|
1277 |
+
self,
|
1278 |
+
in_channels: int,
|
1279 |
+
out_channels: int,
|
1280 |
+
prev_output_channel: int,
|
1281 |
+
temb_channels: int,
|
1282 |
+
resolution_idx: Optional[int] = None,
|
1283 |
+
dropout: float = 0.0,
|
1284 |
+
num_layers: int = 1,
|
1285 |
+
transformer_layers_per_block: int = 1,
|
1286 |
+
resnet_eps: float = 1e-6,
|
1287 |
+
resnet_time_scale_shift: str = "default",
|
1288 |
+
resnet_act_fn: str = "swish",
|
1289 |
+
resnet_groups: int = 32,
|
1290 |
+
resnet_pre_norm: bool = True,
|
1291 |
+
num_attention_heads: int = 1,
|
1292 |
+
cross_attention_dim: int = 1280,
|
1293 |
+
output_scale_factor: float = 1.0,
|
1294 |
+
add_upsample: bool = True,
|
1295 |
+
dual_cross_attention: bool = False,
|
1296 |
+
use_linear_projection: bool = False,
|
1297 |
+
only_cross_attention: bool = False,
|
1298 |
+
upcast_attention: bool = False,
|
1299 |
+
attention_type: str = "default",
|
1300 |
+
temporal_cross_attention_dim: Optional[int] = None,
|
1301 |
+
temporal_num_attention_heads: int = 8,
|
1302 |
+
temporal_max_seq_length: int = 32,
|
1303 |
+
):
|
1304 |
+
super().__init__()
|
1305 |
+
resnets = []
|
1306 |
+
attentions = []
|
1307 |
+
motion_modules = []
|
1308 |
+
|
1309 |
+
self.has_cross_attention = True
|
1310 |
+
self.num_attention_heads = num_attention_heads
|
1311 |
+
|
1312 |
+
for i in range(num_layers):
|
1313 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1314 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1315 |
+
|
1316 |
+
resnets.append(
|
1317 |
+
ResnetBlock2D(
|
1318 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1319 |
+
out_channels=out_channels,
|
1320 |
+
temb_channels=temb_channels,
|
1321 |
+
eps=resnet_eps,
|
1322 |
+
groups=resnet_groups,
|
1323 |
+
dropout=dropout,
|
1324 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1325 |
+
non_linearity=resnet_act_fn,
|
1326 |
+
output_scale_factor=output_scale_factor,
|
1327 |
+
pre_norm=resnet_pre_norm,
|
1328 |
+
)
|
1329 |
+
)
|
1330 |
+
|
1331 |
+
if not dual_cross_attention:
|
1332 |
+
attentions.append(
|
1333 |
+
Transformer2DModel(
|
1334 |
+
num_attention_heads,
|
1335 |
+
out_channels // num_attention_heads,
|
1336 |
+
in_channels=out_channels,
|
1337 |
+
num_layers=transformer_layers_per_block,
|
1338 |
+
cross_attention_dim=cross_attention_dim,
|
1339 |
+
norm_num_groups=resnet_groups,
|
1340 |
+
use_linear_projection=use_linear_projection,
|
1341 |
+
only_cross_attention=only_cross_attention,
|
1342 |
+
upcast_attention=upcast_attention,
|
1343 |
+
attention_type=attention_type,
|
1344 |
+
)
|
1345 |
+
)
|
1346 |
+
else:
|
1347 |
+
attentions.append(
|
1348 |
+
DualTransformer2DModel(
|
1349 |
+
num_attention_heads,
|
1350 |
+
out_channels // num_attention_heads,
|
1351 |
+
in_channels=out_channels,
|
1352 |
+
num_layers=1,
|
1353 |
+
cross_attention_dim=cross_attention_dim,
|
1354 |
+
norm_num_groups=resnet_groups,
|
1355 |
+
)
|
1356 |
+
)
|
1357 |
+
motion_modules.append(
|
1358 |
+
TransformerTemporalModel(
|
1359 |
+
num_attention_heads=temporal_num_attention_heads,
|
1360 |
+
in_channels=out_channels,
|
1361 |
+
norm_num_groups=resnet_groups,
|
1362 |
+
cross_attention_dim=temporal_cross_attention_dim,
|
1363 |
+
attention_bias=False,
|
1364 |
+
activation_fn="geglu",
|
1365 |
+
positional_embeddings="sinusoidal",
|
1366 |
+
num_positional_embeddings=temporal_max_seq_length,
|
1367 |
+
attention_head_dim=out_channels // temporal_num_attention_heads,
|
1368 |
+
)
|
1369 |
+
)
|
1370 |
+
|
1371 |
+
self.attentions = nn.ModuleList(attentions)
|
1372 |
+
self.resnets = nn.ModuleList(resnets)
|
1373 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
1374 |
+
|
1375 |
+
if add_upsample:
|
1376 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1377 |
+
else:
|
1378 |
+
self.upsamplers = None
|
1379 |
+
|
1380 |
+
self.gradient_checkpointing = False
|
1381 |
+
self.resolution_idx = resolution_idx
|
1382 |
+
|
1383 |
+
def forward(
|
1384 |
+
self,
|
1385 |
+
hidden_states: torch.FloatTensor,
|
1386 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1387 |
+
temb: Optional[torch.FloatTensor] = None,
|
1388 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1389 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1390 |
+
upsample_size: Optional[int] = None,
|
1391 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1392 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1393 |
+
num_frames: int = 1,
|
1394 |
+
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
1395 |
+
) -> torch.FloatTensor:
|
1396 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
1397 |
+
is_freeu_enabled = (
|
1398 |
+
getattr(self, "s1", None)
|
1399 |
+
and getattr(self, "s2", None)
|
1400 |
+
and getattr(self, "b1", None)
|
1401 |
+
and getattr(self, "b2", None)
|
1402 |
+
)
|
1403 |
+
|
1404 |
+
blocks = zip(self.resnets, self.attentions, self.motion_modules)
|
1405 |
+
for resnet, attn, motion_module in blocks:
|
1406 |
+
# pop res hidden states
|
1407 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1408 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1409 |
+
|
1410 |
+
# FreeU: Only operate on the first two stages
|
1411 |
+
if is_freeu_enabled:
|
1412 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1413 |
+
self.resolution_idx,
|
1414 |
+
hidden_states,
|
1415 |
+
res_hidden_states,
|
1416 |
+
s1=self.s1,
|
1417 |
+
s2=self.s2,
|
1418 |
+
b1=self.b1,
|
1419 |
+
b2=self.b2,
|
1420 |
+
)
|
1421 |
+
|
1422 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1423 |
+
|
1424 |
+
if self.training and self.gradient_checkpointing:
|
1425 |
+
|
1426 |
+
def create_custom_forward(module, return_dict=None):
|
1427 |
+
def custom_forward(*inputs):
|
1428 |
+
if return_dict is not None:
|
1429 |
+
return module(*inputs, return_dict=return_dict)
|
1430 |
+
else:
|
1431 |
+
return module(*inputs)
|
1432 |
+
|
1433 |
+
return custom_forward
|
1434 |
+
|
1435 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1436 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1437 |
+
create_custom_forward(resnet),
|
1438 |
+
hidden_states,
|
1439 |
+
temb,
|
1440 |
+
**ckpt_kwargs,
|
1441 |
+
)
|
1442 |
+
hidden_states = attn(
|
1443 |
+
hidden_states,
|
1444 |
+
encoder_hidden_states=encoder_hidden_states,
|
1445 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1446 |
+
attention_mask=attention_mask,
|
1447 |
+
encoder_attention_mask=encoder_attention_mask,
|
1448 |
+
return_dict=False,
|
1449 |
+
)[0]
|
1450 |
+
if up_block_add_samples is not None:
|
1451 |
+
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
1452 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1453 |
+
create_custom_forward(motion_module),
|
1454 |
+
hidden_states.requires_grad_(),
|
1455 |
+
temb,
|
1456 |
+
num_frames,
|
1457 |
+
**ckpt_kwargs,
|
1458 |
+
)
|
1459 |
+
else:
|
1460 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1461 |
+
hidden_states = attn(
|
1462 |
+
hidden_states,
|
1463 |
+
encoder_hidden_states=encoder_hidden_states,
|
1464 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1465 |
+
attention_mask=attention_mask,
|
1466 |
+
encoder_attention_mask=encoder_attention_mask,
|
1467 |
+
return_dict=False,
|
1468 |
+
)[0]
|
1469 |
+
if up_block_add_samples is not None:
|
1470 |
+
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
1471 |
+
hidden_states = motion_module(
|
1472 |
+
hidden_states,
|
1473 |
+
num_frames=num_frames,
|
1474 |
+
)
|
1475 |
+
|
1476 |
+
if self.upsamplers is not None:
|
1477 |
+
for upsampler in self.upsamplers:
|
1478 |
+
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
1479 |
+
if up_block_add_samples is not None:
|
1480 |
+
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
1481 |
+
|
1482 |
+
return hidden_states
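# Illustrative sketch (not part of the diff): a small hypothetical helper (not repo API)
# showing how many tensors a caller must queue in up_block_add_samples for one of these
# blocks: one per resnet layer, plus one per upsampler when add_upsample is enabled.
def count_up_residuals(block) -> int:
    n = len(block.resnets)
    if block.upsamplers is not None:
        n += len(block.upsamplers)
    return n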
|
1483 |
+
|
1484 |
+
|
1485 |
+
class UpBlockMotion(nn.Module):
|
1486 |
+
def __init__(
|
1487 |
+
self,
|
1488 |
+
in_channels: int,
|
1489 |
+
prev_output_channel: int,
|
1490 |
+
out_channels: int,
|
1491 |
+
temb_channels: int,
|
1492 |
+
resolution_idx: Optional[int] = None,
|
1493 |
+
dropout: float = 0.0,
|
1494 |
+
num_layers: int = 1,
|
1495 |
+
resnet_eps: float = 1e-6,
|
1496 |
+
resnet_time_scale_shift: str = "default",
|
1497 |
+
resnet_act_fn: str = "swish",
|
1498 |
+
resnet_groups: int = 32,
|
1499 |
+
resnet_pre_norm: bool = True,
|
1500 |
+
output_scale_factor: float = 1.0,
|
1501 |
+
add_upsample: bool = True,
|
1502 |
+
temporal_norm_num_groups: int = 32,
|
1503 |
+
temporal_cross_attention_dim: Optional[int] = None,
|
1504 |
+
temporal_num_attention_heads: int = 8,
|
1505 |
+
temporal_max_seq_length: int = 32,
|
1506 |
+
):
|
1507 |
+
super().__init__()
|
1508 |
+
resnets = []
|
1509 |
+
motion_modules = []
|
1510 |
+
|
1511 |
+
for i in range(num_layers):
|
1512 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1513 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1514 |
+
|
1515 |
+
resnets.append(
|
1516 |
+
ResnetBlock2D(
|
1517 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1518 |
+
out_channels=out_channels,
|
1519 |
+
temb_channels=temb_channels,
|
1520 |
+
eps=resnet_eps,
|
1521 |
+
groups=resnet_groups,
|
1522 |
+
dropout=dropout,
|
1523 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1524 |
+
non_linearity=resnet_act_fn,
|
1525 |
+
output_scale_factor=output_scale_factor,
|
1526 |
+
pre_norm=resnet_pre_norm,
|
1527 |
+
)
|
1528 |
+
)
|
1529 |
+
|
1530 |
+
motion_modules.append(
|
1531 |
+
TransformerTemporalModel(
|
1532 |
+
num_attention_heads=temporal_num_attention_heads,
|
1533 |
+
in_channels=out_channels,
|
1534 |
+
norm_num_groups=temporal_norm_num_groups,
|
1535 |
+
cross_attention_dim=temporal_cross_attention_dim,
|
1536 |
+
attention_bias=False,
|
1537 |
+
activation_fn="geglu",
|
1538 |
+
positional_embeddings="sinusoidal",
|
1539 |
+
num_positional_embeddings=temporal_max_seq_length,
|
1540 |
+
attention_head_dim=out_channels // temporal_num_attention_heads,
|
1541 |
+
)
|
1542 |
+
)
|
1543 |
+
|
1544 |
+
self.resnets = nn.ModuleList(resnets)
|
1545 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
1546 |
+
|
1547 |
+
if add_upsample:
|
1548 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1549 |
+
else:
|
1550 |
+
self.upsamplers = None
|
1551 |
+
|
1552 |
+
self.gradient_checkpointing = False
|
1553 |
+
self.resolution_idx = resolution_idx
|
1554 |
+
|
1555 |
+
def forward(
|
1556 |
+
self,
|
1557 |
+
hidden_states: torch.FloatTensor,
|
1558 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1559 |
+
temb: Optional[torch.FloatTensor] = None,
|
1560 |
+
upsample_size=None,
|
1561 |
+
scale: float = 1.0,
|
1562 |
+
num_frames: int = 1,
|
1563 |
+
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
1564 |
+
) -> torch.FloatTensor:
|
1565 |
+
is_freeu_enabled = (
|
1566 |
+
getattr(self, "s1", None)
|
1567 |
+
and getattr(self, "s2", None)
|
1568 |
+
and getattr(self, "b1", None)
|
1569 |
+
and getattr(self, "b2", None)
|
1570 |
+
)
|
1571 |
+
|
1572 |
+
blocks = zip(self.resnets, self.motion_modules)
|
1573 |
+
|
1574 |
+
for resnet, motion_module in blocks:
|
1575 |
+
# pop res hidden states
|
1576 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1577 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1578 |
+
|
1579 |
+
# FreeU: Only operate on the first two stages
|
1580 |
+
if is_freeu_enabled:
|
1581 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1582 |
+
self.resolution_idx,
|
1583 |
+
hidden_states,
|
1584 |
+
res_hidden_states,
|
1585 |
+
s1=self.s1,
|
1586 |
+
s2=self.s2,
|
1587 |
+
b1=self.b1,
|
1588 |
+
b2=self.b2,
|
1589 |
+
)
|
1590 |
+
|
1591 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1592 |
+
|
1593 |
+
if self.training and self.gradient_checkpointing:
|
1594 |
+
|
1595 |
+
def create_custom_forward(module):
|
1596 |
+
def custom_forward(*inputs):
|
1597 |
+
return module(*inputs)
|
1598 |
+
|
1599 |
+
return custom_forward
|
1600 |
+
|
1601 |
+
if is_torch_version(">=", "1.11.0"):
|
1602 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1603 |
+
create_custom_forward(resnet),
|
1604 |
+
hidden_states,
|
1605 |
+
temb,
|
1606 |
+
use_reentrant=False,
|
1607 |
+
)
|
1608 |
+
else:
|
1609 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1610 |
+
create_custom_forward(resnet), hidden_states, temb
|
1611 |
+
)
|
1612 |
+
|
1613 |
+
if up_block_add_samples is not None:
|
1614 |
+
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
1615 |
+
|
1616 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1617 |
+
create_custom_forward(motion_module),
|
1618 |
+
hidden_states.requires_grad_(),
|
1619 |
+
temb,
|
1620 |
+
num_frames,
|
1621 |
+
use_reentrant=False,
|
1622 |
+
)
|
1623 |
+
else:
|
1624 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1625 |
+
if up_block_add_samples is not None:
|
1626 |
+
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
1627 |
+
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
1628 |
+
|
1629 |
+
if self.upsamplers is not None:
|
1630 |
+
for upsampler in self.upsamplers:
|
1631 |
+
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
1632 |
+
|
1633 |
+
if up_block_add_samples is not None:
|
1634 |
+
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
1635 |
+
|
1636 |
+
return hidden_states
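# Illustrative sketch (not part of the diff): upsample_size lets the caller pin the
# upsampler to an exact spatial size, which matters when the input resolution is not a
# multiple of the UNet's total downsampling factor. All sizes below are assumptions.
motion_up = UpBlockMotion(
    in_channels=320, prev_output_channel=640, out_channels=320,
    temb_channels=1280, num_layers=1,
)
sample = torch.randn(8, 640, 17, 17)
skips = (torch.randn(8, 320, 17, 17),)
sample = motion_up(
    sample, skips, temb=torch.randn(8, 1280),
    upsample_size=(34, 34), num_frames=8,
)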
|
1637 |
+
|
1638 |
+
|
1639 |
+
class UNetMidBlockCrossAttnMotion(nn.Module):
|
1640 |
+
def __init__(
|
1641 |
+
self,
|
1642 |
+
in_channels: int,
|
1643 |
+
temb_channels: int,
|
1644 |
+
dropout: float = 0.0,
|
1645 |
+
num_layers: int = 1,
|
1646 |
+
transformer_layers_per_block: int = 1,
|
1647 |
+
resnet_eps: float = 1e-6,
|
1648 |
+
resnet_time_scale_shift: str = "default",
|
1649 |
+
resnet_act_fn: str = "swish",
|
1650 |
+
resnet_groups: int = 32,
|
1651 |
+
resnet_pre_norm: bool = True,
|
1652 |
+
num_attention_heads: int = 1,
|
1653 |
+
output_scale_factor: float = 1.0,
|
1654 |
+
cross_attention_dim: int = 1280,
|
1655 |
+
dual_cross_attention: bool = False,
|
1656 |
+
use_linear_projection: bool = False,
|
1657 |
+
upcast_attention: bool = False,
|
1658 |
+
attention_type: str = "default",
|
1659 |
+
temporal_num_attention_heads: int = 1,
|
1660 |
+
temporal_cross_attention_dim: Optional[int] = None,
|
1661 |
+
temporal_max_seq_length: int = 32,
|
1662 |
+
):
|
1663 |
+
super().__init__()
|
1664 |
+
|
1665 |
+
self.has_cross_attention = True
|
1666 |
+
self.num_attention_heads = num_attention_heads
|
1667 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
1668 |
+
|
1669 |
+
# there is always at least one resnet
|
1670 |
+
resnets = [
|
1671 |
+
ResnetBlock2D(
|
1672 |
+
in_channels=in_channels,
|
1673 |
+
out_channels=in_channels,
|
1674 |
+
temb_channels=temb_channels,
|
1675 |
+
eps=resnet_eps,
|
1676 |
+
groups=resnet_groups,
|
1677 |
+
dropout=dropout,
|
1678 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1679 |
+
non_linearity=resnet_act_fn,
|
1680 |
+
output_scale_factor=output_scale_factor,
|
1681 |
+
pre_norm=resnet_pre_norm,
|
1682 |
+
)
|
1683 |
+
]
|
1684 |
+
attentions = []
|
1685 |
+
motion_modules = []
|
1686 |
+
|
1687 |
+
for _ in range(num_layers):
|
1688 |
+
if not dual_cross_attention:
|
1689 |
+
attentions.append(
|
1690 |
+
Transformer2DModel(
|
1691 |
+
num_attention_heads,
|
1692 |
+
in_channels // num_attention_heads,
|
1693 |
+
in_channels=in_channels,
|
1694 |
+
num_layers=transformer_layers_per_block,
|
1695 |
+
cross_attention_dim=cross_attention_dim,
|
1696 |
+
norm_num_groups=resnet_groups,
|
1697 |
+
use_linear_projection=use_linear_projection,
|
1698 |
+
upcast_attention=upcast_attention,
|
1699 |
+
attention_type=attention_type,
|
1700 |
+
)
|
1701 |
+
)
|
1702 |
+
else:
|
1703 |
+
attentions.append(
|
1704 |
+
DualTransformer2DModel(
|
1705 |
+
num_attention_heads,
|
1706 |
+
in_channels // num_attention_heads,
|
1707 |
+
in_channels=in_channels,
|
1708 |
+
num_layers=1,
|
1709 |
+
cross_attention_dim=cross_attention_dim,
|
1710 |
+
norm_num_groups=resnet_groups,
|
1711 |
+
)
|
1712 |
+
)
|
1713 |
+
resnets.append(
|
1714 |
+
ResnetBlock2D(
|
1715 |
+
in_channels=in_channels,
|
1716 |
+
out_channels=in_channels,
|
1717 |
+
temb_channels=temb_channels,
|
1718 |
+
eps=resnet_eps,
|
1719 |
+
groups=resnet_groups,
|
1720 |
+
dropout=dropout,
|
1721 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1722 |
+
non_linearity=resnet_act_fn,
|
1723 |
+
output_scale_factor=output_scale_factor,
|
1724 |
+
pre_norm=resnet_pre_norm,
|
1725 |
+
)
|
1726 |
+
)
|
1727 |
+
motion_modules.append(
|
1728 |
+
TransformerTemporalModel(
|
1729 |
+
num_attention_heads=temporal_num_attention_heads,
|
1730 |
+
attention_head_dim=in_channels // temporal_num_attention_heads,
|
1731 |
+
in_channels=in_channels,
|
1732 |
+
norm_num_groups=resnet_groups,
|
1733 |
+
cross_attention_dim=temporal_cross_attention_dim,
|
1734 |
+
attention_bias=False,
|
1735 |
+
positional_embeddings="sinusoidal",
|
1736 |
+
num_positional_embeddings=temporal_max_seq_length,
|
1737 |
+
activation_fn="geglu",
|
1738 |
+
)
|
1739 |
+
)
|
1740 |
+
|
1741 |
+
self.attentions = nn.ModuleList(attentions)
|
1742 |
+
self.resnets = nn.ModuleList(resnets)
|
1743 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
1744 |
+
|
1745 |
+
self.gradient_checkpointing = False
|
1746 |
+
|
1747 |
+
def forward(
|
1748 |
+
self,
|
1749 |
+
hidden_states: torch.FloatTensor,
|
1750 |
+
temb: Optional[torch.FloatTensor] = None,
|
1751 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1752 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1753 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1754 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1755 |
+
num_frames: int = 1,
|
1756 |
+
mid_block_add_sample: Optional[torch.Tensor] = None,
|
1757 |
+
) -> torch.FloatTensor:
|
1758 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
1759 |
+
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
1760 |
+
|
1761 |
+
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
|
1762 |
+
for attn, resnet, motion_module in blocks:
|
1763 |
+
if self.training and self.gradient_checkpointing:
|
1764 |
+
|
1765 |
+
def create_custom_forward(module, return_dict=None):
|
1766 |
+
def custom_forward(*inputs):
|
1767 |
+
if return_dict is not None:
|
1768 |
+
return module(*inputs, return_dict=return_dict)
|
1769 |
+
else:
|
1770 |
+
return module(*inputs)
|
1771 |
+
|
1772 |
+
return custom_forward
|
1773 |
+
|
1774 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1775 |
+
hidden_states = attn(
|
1776 |
+
hidden_states,
|
1777 |
+
encoder_hidden_states=encoder_hidden_states,
|
1778 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1779 |
+
attention_mask=attention_mask,
|
1780 |
+
encoder_attention_mask=encoder_attention_mask,
|
1781 |
+
return_dict=False,
|
1782 |
+
)[0]
|
1783 |
+
##########
|
1784 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1785 |
+
if mid_block_add_sample is not None:
|
1786 |
+
hidden_states = hidden_states + mid_block_add_sample
|
1787 |
+
################################################################
|
1788 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1789 |
+
create_custom_forward(motion_module),
|
1790 |
+
hidden_states.requires_grad_(),
|
1791 |
+
temb,
|
1792 |
+
num_frames,
|
1793 |
+
**ckpt_kwargs,
|
1794 |
+
)
|
1795 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1796 |
+
create_custom_forward(resnet),
|
1797 |
+
hidden_states,
|
1798 |
+
temb,
|
1799 |
+
**ckpt_kwargs,
|
1800 |
+
)
|
1801 |
+
else:
|
1802 |
+
hidden_states = attn(
|
1803 |
+
hidden_states,
|
1804 |
+
encoder_hidden_states=encoder_hidden_states,
|
1805 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1806 |
+
attention_mask=attention_mask,
|
1807 |
+
encoder_attention_mask=encoder_attention_mask,
|
1808 |
+
return_dict=False,
|
1809 |
+
)[0]
|
1810 |
+
##########
|
1811 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1812 |
+
if mid_block_add_sample is not None:
|
1813 |
+
hidden_states = hidden_states + mid_block_add_sample
|
1814 |
+
################################################################
|
1815 |
+
hidden_states = motion_module(
|
1816 |
+
hidden_states,
|
1817 |
+
num_frames=num_frames,
|
1818 |
+
)
|
1819 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1820 |
+
|
1821 |
+
return hidden_states
|
1822 |
+

class MidBlockTemporalDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        attention_head_dim: int = 512,
        num_layers: int = 1,
        upcast_attention: bool = False,
    ):
        super().__init__()

        resnets = []
        attentions = []
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=1e-6,
                    temporal_eps=1e-5,
                    merge_factor=0.0,
                    merge_strategy="learned",
                    switch_spatial_to_temporal_mix=True,
                )
            )

        attentions.append(
            Attention(
                query_dim=in_channels,
                heads=in_channels // attention_head_dim,
                dim_head=attention_head_dim,
                eps=1e-6,
                upcast_attention=upcast_attention,
                norm_num_groups=32,
                bias=True,
                residual_connection=True,
            )
        )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        image_only_indicator: torch.FloatTensor,
    ):
        hidden_states = self.resnets[0](
            hidden_states,
            image_only_indicator=image_only_indicator,
        )
        for resnet, attn in zip(self.resnets[1:], self.attentions):
            hidden_states = attn(hidden_states)
            hidden_states = resnet(
                hidden_states,
                image_only_indicator=image_only_indicator,
            )

        return hidden_states


class UpBlockTemporalDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=1e-6,
                    temporal_eps=1e-5,
                    merge_factor=0.0,
                    merge_strategy="learned",
                    switch_spatial_to_temporal_mix=True,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        image_only_indicator: torch.FloatTensor,
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(
                hidden_states,
                image_only_indicator=image_only_indicator,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states

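The two temporal-decoder blocks above only need pixel features plus an `image_only_indicator` mask; they take no timestep embedding. Below is a minimal, hedged sketch of exercising them in isolation, assuming this repo is importable as a `libs` package and that `SpatioTemporalResBlock`, `Attention`, and `Upsample2D` behave as in upstream diffusers; the channel counts, frame count, and shapes are illustrative only.

```python
import torch

from libs.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder

batch, frames, channels = 1, 4, 512
hidden = torch.randn(batch * frames, channels, 8, 8)   # frames are folded into the batch dimension
indicator = torch.zeros(batch, frames)                  # 0 = video frame, 1 = treat the frame as a still image

# num_layers=2 so the single spatial Attention layer is actually applied between the two resnets.
mid = MidBlockTemporalDecoder(in_channels=channels, out_channels=channels, attention_head_dim=512, num_layers=2)
up = UpBlockTemporalDecoder(in_channels=channels, out_channels=256, add_upsample=True)

out = mid(hidden, image_only_indicator=indicator)       # spatio-temporal resnets + spatial attention, shape preserved
out = up(out, image_only_indicator=indicator)           # channels 512 -> 256, spatial 8x8 -> 16x16 via Upsample2D
print(out.shape)                                        # expected: torch.Size([4, 256, 16, 16])
```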
class UNetMidBlockSpatioTemporal(nn.Module):
|
1938 |
+
def __init__(
|
1939 |
+
self,
|
1940 |
+
in_channels: int,
|
1941 |
+
temb_channels: int,
|
1942 |
+
num_layers: int = 1,
|
1943 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
1944 |
+
num_attention_heads: int = 1,
|
1945 |
+
cross_attention_dim: int = 1280,
|
1946 |
+
):
|
1947 |
+
super().__init__()
|
1948 |
+
|
1949 |
+
self.has_cross_attention = True
|
1950 |
+
self.num_attention_heads = num_attention_heads
|
1951 |
+
|
1952 |
+
# support for variable transformer layers per block
|
1953 |
+
if isinstance(transformer_layers_per_block, int):
|
1954 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
1955 |
+
|
1956 |
+
# there is always at least one resnet
|
1957 |
+
resnets = [
|
1958 |
+
SpatioTemporalResBlock(
|
1959 |
+
in_channels=in_channels,
|
1960 |
+
out_channels=in_channels,
|
1961 |
+
temb_channels=temb_channels,
|
1962 |
+
eps=1e-5,
|
1963 |
+
)
|
1964 |
+
]
|
1965 |
+
attentions = []
|
1966 |
+
|
1967 |
+
for i in range(num_layers):
|
1968 |
+
attentions.append(
|
1969 |
+
TransformerSpatioTemporalModel(
|
1970 |
+
num_attention_heads,
|
1971 |
+
in_channels // num_attention_heads,
|
1972 |
+
in_channels=in_channels,
|
1973 |
+
num_layers=transformer_layers_per_block[i],
|
1974 |
+
cross_attention_dim=cross_attention_dim,
|
1975 |
+
)
|
1976 |
+
)
|
1977 |
+
|
1978 |
+
resnets.append(
|
1979 |
+
SpatioTemporalResBlock(
|
1980 |
+
in_channels=in_channels,
|
1981 |
+
out_channels=in_channels,
|
1982 |
+
temb_channels=temb_channels,
|
1983 |
+
eps=1e-5,
|
1984 |
+
)
|
1985 |
+
)
|
1986 |
+
|
1987 |
+
self.attentions = nn.ModuleList(attentions)
|
1988 |
+
self.resnets = nn.ModuleList(resnets)
|
1989 |
+
|
1990 |
+
self.gradient_checkpointing = False
|
1991 |
+
|
1992 |
+
def forward(
|
1993 |
+
self,
|
1994 |
+
hidden_states: torch.FloatTensor,
|
1995 |
+
temb: Optional[torch.FloatTensor] = None,
|
1996 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1997 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
1998 |
+
) -> torch.FloatTensor:
|
1999 |
+
hidden_states = self.resnets[0](
|
2000 |
+
hidden_states,
|
2001 |
+
temb,
|
2002 |
+
image_only_indicator=image_only_indicator,
|
2003 |
+
)
|
2004 |
+
|
2005 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
2006 |
+
if self.training and self.gradient_checkpointing: # TODO
|
2007 |
+
|
2008 |
+
def create_custom_forward(module, return_dict=None):
|
2009 |
+
def custom_forward(*inputs):
|
2010 |
+
if return_dict is not None:
|
2011 |
+
return module(*inputs, return_dict=return_dict)
|
2012 |
+
else:
|
2013 |
+
return module(*inputs)
|
2014 |
+
|
2015 |
+
return custom_forward
|
2016 |
+
|
2017 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
2018 |
+
hidden_states = attn(
|
2019 |
+
hidden_states,
|
2020 |
+
encoder_hidden_states=encoder_hidden_states,
|
2021 |
+
image_only_indicator=image_only_indicator,
|
2022 |
+
return_dict=False,
|
2023 |
+
)[0]
|
2024 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2025 |
+
create_custom_forward(resnet),
|
2026 |
+
hidden_states,
|
2027 |
+
temb,
|
2028 |
+
image_only_indicator,
|
2029 |
+
**ckpt_kwargs,
|
2030 |
+
)
|
2031 |
+
else:
|
2032 |
+
hidden_states = attn(
|
2033 |
+
hidden_states,
|
2034 |
+
encoder_hidden_states=encoder_hidden_states,
|
2035 |
+
image_only_indicator=image_only_indicator,
|
2036 |
+
return_dict=False,
|
2037 |
+
)[0]
|
2038 |
+
hidden_states = resnet(
|
2039 |
+
hidden_states,
|
2040 |
+
temb,
|
2041 |
+
image_only_indicator=image_only_indicator,
|
2042 |
+
)
|
2043 |
+
|
2044 |
+
return hidden_states
|
2045 |
+
|
2046 |
+
|
2047 |
+
class DownBlockSpatioTemporal(nn.Module):
|
2048 |
+
def __init__(
|
2049 |
+
self,
|
2050 |
+
in_channels: int,
|
2051 |
+
out_channels: int,
|
2052 |
+
temb_channels: int,
|
2053 |
+
num_layers: int = 1,
|
2054 |
+
add_downsample: bool = True,
|
2055 |
+
):
|
2056 |
+
super().__init__()
|
2057 |
+
resnets = []
|
2058 |
+
|
2059 |
+
for i in range(num_layers):
|
2060 |
+
in_channels = in_channels if i == 0 else out_channels
|
2061 |
+
resnets.append(
|
2062 |
+
SpatioTemporalResBlock(
|
2063 |
+
in_channels=in_channels,
|
2064 |
+
out_channels=out_channels,
|
2065 |
+
temb_channels=temb_channels,
|
2066 |
+
eps=1e-5,
|
2067 |
+
)
|
2068 |
+
)
|
2069 |
+
|
2070 |
+
self.resnets = nn.ModuleList(resnets)
|
2071 |
+
|
2072 |
+
if add_downsample:
|
2073 |
+
self.downsamplers = nn.ModuleList(
|
2074 |
+
[
|
2075 |
+
Downsample2D(
|
2076 |
+
out_channels,
|
2077 |
+
use_conv=True,
|
2078 |
+
out_channels=out_channels,
|
2079 |
+
name="op",
|
2080 |
+
)
|
2081 |
+
]
|
2082 |
+
)
|
2083 |
+
else:
|
2084 |
+
self.downsamplers = None
|
2085 |
+
|
2086 |
+
self.gradient_checkpointing = False
|
2087 |
+
|
2088 |
+
def forward(
|
2089 |
+
self,
|
2090 |
+
hidden_states: torch.FloatTensor,
|
2091 |
+
temb: Optional[torch.FloatTensor] = None,
|
2092 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
2093 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
2094 |
+
output_states = ()
|
2095 |
+
for resnet in self.resnets:
|
2096 |
+
if self.training and self.gradient_checkpointing:
|
2097 |
+
|
2098 |
+
def create_custom_forward(module):
|
2099 |
+
def custom_forward(*inputs):
|
2100 |
+
return module(*inputs)
|
2101 |
+
|
2102 |
+
return custom_forward
|
2103 |
+
|
2104 |
+
if is_torch_version(">=", "1.11.0"):
|
2105 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2106 |
+
create_custom_forward(resnet),
|
2107 |
+
hidden_states,
|
2108 |
+
temb,
|
2109 |
+
image_only_indicator,
|
2110 |
+
use_reentrant=False,
|
2111 |
+
)
|
2112 |
+
else:
|
2113 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2114 |
+
create_custom_forward(resnet),
|
2115 |
+
hidden_states,
|
2116 |
+
temb,
|
2117 |
+
image_only_indicator,
|
2118 |
+
)
|
2119 |
+
else:
|
2120 |
+
hidden_states = resnet(
|
2121 |
+
hidden_states,
|
2122 |
+
temb,
|
2123 |
+
image_only_indicator=image_only_indicator,
|
2124 |
+
)
|
2125 |
+
|
2126 |
+
output_states = output_states + (hidden_states,)
|
2127 |
+
|
2128 |
+
if self.downsamplers is not None:
|
2129 |
+
for downsampler in self.downsamplers:
|
2130 |
+
hidden_states = downsampler(hidden_states)
|
2131 |
+
|
2132 |
+
output_states = output_states + (hidden_states,)
|
2133 |
+
|
2134 |
+
return hidden_states, output_states
|
2135 |
+
|
2136 |
+
|
2137 |
+
class CrossAttnDownBlockSpatioTemporal(nn.Module):
|
2138 |
+
def __init__(
|
2139 |
+
self,
|
2140 |
+
in_channels: int,
|
2141 |
+
out_channels: int,
|
2142 |
+
temb_channels: int,
|
2143 |
+
num_layers: int = 1,
|
2144 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
2145 |
+
num_attention_heads: int = 1,
|
2146 |
+
cross_attention_dim: int = 1280,
|
2147 |
+
add_downsample: bool = True,
|
2148 |
+
):
|
2149 |
+
super().__init__()
|
2150 |
+
resnets = []
|
2151 |
+
attentions = []
|
2152 |
+
|
2153 |
+
self.has_cross_attention = True
|
2154 |
+
self.num_attention_heads = num_attention_heads
|
2155 |
+
if isinstance(transformer_layers_per_block, int):
|
2156 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
2157 |
+
|
2158 |
+
for i in range(num_layers):
|
2159 |
+
in_channels = in_channels if i == 0 else out_channels
|
2160 |
+
resnets.append(
|
2161 |
+
SpatioTemporalResBlock(
|
2162 |
+
in_channels=in_channels,
|
2163 |
+
out_channels=out_channels,
|
2164 |
+
temb_channels=temb_channels,
|
2165 |
+
eps=1e-6,
|
2166 |
+
)
|
2167 |
+
)
|
2168 |
+
attentions.append(
|
2169 |
+
TransformerSpatioTemporalModel(
|
2170 |
+
num_attention_heads,
|
2171 |
+
out_channels // num_attention_heads,
|
2172 |
+
in_channels=out_channels,
|
2173 |
+
num_layers=transformer_layers_per_block[i],
|
2174 |
+
cross_attention_dim=cross_attention_dim,
|
2175 |
+
)
|
2176 |
+
)
|
2177 |
+
|
2178 |
+
self.attentions = nn.ModuleList(attentions)
|
2179 |
+
self.resnets = nn.ModuleList(resnets)
|
2180 |
+
|
2181 |
+
if add_downsample:
|
2182 |
+
self.downsamplers = nn.ModuleList(
|
2183 |
+
[
|
2184 |
+
Downsample2D(
|
2185 |
+
out_channels,
|
2186 |
+
use_conv=True,
|
2187 |
+
out_channels=out_channels,
|
2188 |
+
padding=1,
|
2189 |
+
name="op",
|
2190 |
+
)
|
2191 |
+
]
|
2192 |
+
)
|
2193 |
+
else:
|
2194 |
+
self.downsamplers = None
|
2195 |
+
|
2196 |
+
self.gradient_checkpointing = False
|
2197 |
+
|
2198 |
+
def forward(
|
2199 |
+
self,
|
2200 |
+
hidden_states: torch.FloatTensor,
|
2201 |
+
temb: Optional[torch.FloatTensor] = None,
|
2202 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
2203 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
2204 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
2205 |
+
output_states = ()
|
2206 |
+
|
2207 |
+
blocks = list(zip(self.resnets, self.attentions))
|
2208 |
+
for resnet, attn in blocks:
|
2209 |
+
if self.training and self.gradient_checkpointing: # TODO
|
2210 |
+
|
2211 |
+
def create_custom_forward(module, return_dict=None):
|
2212 |
+
def custom_forward(*inputs):
|
2213 |
+
if return_dict is not None:
|
2214 |
+
return module(*inputs, return_dict=return_dict)
|
2215 |
+
else:
|
2216 |
+
return module(*inputs)
|
2217 |
+
|
2218 |
+
return custom_forward
|
2219 |
+
|
2220 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
2221 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2222 |
+
create_custom_forward(resnet),
|
2223 |
+
hidden_states,
|
2224 |
+
temb,
|
2225 |
+
image_only_indicator,
|
2226 |
+
**ckpt_kwargs,
|
2227 |
+
)
|
2228 |
+
|
2229 |
+
hidden_states = attn(
|
2230 |
+
hidden_states,
|
2231 |
+
encoder_hidden_states=encoder_hidden_states,
|
2232 |
+
image_only_indicator=image_only_indicator,
|
2233 |
+
return_dict=False,
|
2234 |
+
)[0]
|
2235 |
+
else:
|
2236 |
+
hidden_states = resnet(
|
2237 |
+
hidden_states,
|
2238 |
+
temb,
|
2239 |
+
image_only_indicator=image_only_indicator,
|
2240 |
+
)
|
2241 |
+
hidden_states = attn(
|
2242 |
+
hidden_states,
|
2243 |
+
encoder_hidden_states=encoder_hidden_states,
|
2244 |
+
image_only_indicator=image_only_indicator,
|
2245 |
+
return_dict=False,
|
2246 |
+
)[0]
|
2247 |
+
|
2248 |
+
output_states = output_states + (hidden_states,)
|
2249 |
+
|
2250 |
+
if self.downsamplers is not None:
|
2251 |
+
for downsampler in self.downsamplers:
|
2252 |
+
hidden_states = downsampler(hidden_states)
|
2253 |
+
|
2254 |
+
output_states = output_states + (hidden_states,)
|
2255 |
+
|
2256 |
+
return hidden_states, output_states
|
2257 |
+
|
2258 |
+
|
2259 |
+
class UpBlockSpatioTemporal(nn.Module):
|
2260 |
+
def __init__(
|
2261 |
+
self,
|
2262 |
+
in_channels: int,
|
2263 |
+
prev_output_channel: int,
|
2264 |
+
out_channels: int,
|
2265 |
+
temb_channels: int,
|
2266 |
+
resolution_idx: Optional[int] = None,
|
2267 |
+
num_layers: int = 1,
|
2268 |
+
resnet_eps: float = 1e-6,
|
2269 |
+
add_upsample: bool = True,
|
2270 |
+
):
|
2271 |
+
super().__init__()
|
2272 |
+
resnets = []
|
2273 |
+
|
2274 |
+
for i in range(num_layers):
|
2275 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
2276 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
2277 |
+
|
2278 |
+
resnets.append(
|
2279 |
+
SpatioTemporalResBlock(
|
2280 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
2281 |
+
out_channels=out_channels,
|
2282 |
+
temb_channels=temb_channels,
|
2283 |
+
eps=resnet_eps,
|
2284 |
+
)
|
2285 |
+
)
|
2286 |
+
|
2287 |
+
self.resnets = nn.ModuleList(resnets)
|
2288 |
+
|
2289 |
+
if add_upsample:
|
2290 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
2291 |
+
else:
|
2292 |
+
self.upsamplers = None
|
2293 |
+
|
2294 |
+
self.gradient_checkpointing = False
|
2295 |
+
self.resolution_idx = resolution_idx
|
2296 |
+
|
2297 |
+
def forward(
|
2298 |
+
self,
|
2299 |
+
hidden_states: torch.FloatTensor,
|
2300 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
2301 |
+
temb: Optional[torch.FloatTensor] = None,
|
2302 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
2303 |
+
) -> torch.FloatTensor:
|
2304 |
+
for resnet in self.resnets:
|
2305 |
+
# pop res hidden states
|
2306 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
2307 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
2308 |
+
|
2309 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
2310 |
+
|
2311 |
+
if self.training and self.gradient_checkpointing:
|
2312 |
+
|
2313 |
+
def create_custom_forward(module):
|
2314 |
+
def custom_forward(*inputs):
|
2315 |
+
return module(*inputs)
|
2316 |
+
|
2317 |
+
return custom_forward
|
2318 |
+
|
2319 |
+
if is_torch_version(">=", "1.11.0"):
|
2320 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2321 |
+
create_custom_forward(resnet),
|
2322 |
+
hidden_states,
|
2323 |
+
temb,
|
2324 |
+
image_only_indicator,
|
2325 |
+
use_reentrant=False,
|
2326 |
+
)
|
2327 |
+
else:
|
2328 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2329 |
+
create_custom_forward(resnet),
|
2330 |
+
hidden_states,
|
2331 |
+
temb,
|
2332 |
+
image_only_indicator,
|
2333 |
+
)
|
2334 |
+
else:
|
2335 |
+
hidden_states = resnet(
|
2336 |
+
hidden_states,
|
2337 |
+
temb,
|
2338 |
+
image_only_indicator=image_only_indicator,
|
2339 |
+
)
|
2340 |
+
|
2341 |
+
if self.upsamplers is not None:
|
2342 |
+
for upsampler in self.upsamplers:
|
2343 |
+
hidden_states = upsampler(hidden_states)
|
2344 |
+
|
2345 |
+
return hidden_states
|
2346 |
+
|
2347 |
+
|
2348 |
+
class CrossAttnUpBlockSpatioTemporal(nn.Module):
|
2349 |
+
def __init__(
|
2350 |
+
self,
|
2351 |
+
in_channels: int,
|
2352 |
+
out_channels: int,
|
2353 |
+
prev_output_channel: int,
|
2354 |
+
temb_channels: int,
|
2355 |
+
resolution_idx: Optional[int] = None,
|
2356 |
+
num_layers: int = 1,
|
2357 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
2358 |
+
resnet_eps: float = 1e-6,
|
2359 |
+
num_attention_heads: int = 1,
|
2360 |
+
cross_attention_dim: int = 1280,
|
2361 |
+
add_upsample: bool = True,
|
2362 |
+
):
|
2363 |
+
super().__init__()
|
2364 |
+
resnets = []
|
2365 |
+
attentions = []
|
2366 |
+
|
2367 |
+
self.has_cross_attention = True
|
2368 |
+
self.num_attention_heads = num_attention_heads
|
2369 |
+
|
2370 |
+
if isinstance(transformer_layers_per_block, int):
|
2371 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
2372 |
+
|
2373 |
+
for i in range(num_layers):
|
2374 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
2375 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
2376 |
+
|
2377 |
+
resnets.append(
|
2378 |
+
SpatioTemporalResBlock(
|
2379 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
2380 |
+
out_channels=out_channels,
|
2381 |
+
temb_channels=temb_channels,
|
2382 |
+
eps=resnet_eps,
|
2383 |
+
)
|
2384 |
+
)
|
2385 |
+
attentions.append(
|
2386 |
+
TransformerSpatioTemporalModel(
|
2387 |
+
num_attention_heads,
|
2388 |
+
out_channels // num_attention_heads,
|
2389 |
+
in_channels=out_channels,
|
2390 |
+
num_layers=transformer_layers_per_block[i],
|
2391 |
+
cross_attention_dim=cross_attention_dim,
|
2392 |
+
)
|
2393 |
+
)
|
2394 |
+
|
2395 |
+
self.attentions = nn.ModuleList(attentions)
|
2396 |
+
self.resnets = nn.ModuleList(resnets)
|
2397 |
+
|
2398 |
+
if add_upsample:
|
2399 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
2400 |
+
else:
|
2401 |
+
self.upsamplers = None
|
2402 |
+
|
2403 |
+
self.gradient_checkpointing = False
|
2404 |
+
self.resolution_idx = resolution_idx
|
2405 |
+
|
2406 |
+
def forward(
|
2407 |
+
self,
|
2408 |
+
hidden_states: torch.FloatTensor,
|
2409 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
2410 |
+
temb: Optional[torch.FloatTensor] = None,
|
2411 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
2412 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
2413 |
+
) -> torch.FloatTensor:
|
2414 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
2415 |
+
# pop res hidden states
|
2416 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
2417 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
2418 |
+
|
2419 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
2420 |
+
|
2421 |
+
if self.training and self.gradient_checkpointing: # TODO
|
2422 |
+
|
2423 |
+
def create_custom_forward(module, return_dict=None):
|
2424 |
+
def custom_forward(*inputs):
|
2425 |
+
if return_dict is not None:
|
2426 |
+
return module(*inputs, return_dict=return_dict)
|
2427 |
+
else:
|
2428 |
+
return module(*inputs)
|
2429 |
+
|
2430 |
+
return custom_forward
|
2431 |
+
|
2432 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
2433 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
2434 |
+
create_custom_forward(resnet),
|
2435 |
+
hidden_states,
|
2436 |
+
temb,
|
2437 |
+
image_only_indicator,
|
2438 |
+
**ckpt_kwargs,
|
2439 |
+
)
|
2440 |
+
hidden_states = attn(
|
2441 |
+
hidden_states,
|
2442 |
+
encoder_hidden_states=encoder_hidden_states,
|
2443 |
+
image_only_indicator=image_only_indicator,
|
2444 |
+
return_dict=False,
|
2445 |
+
)[0]
|
2446 |
+
else:
|
2447 |
+
hidden_states = resnet(
|
2448 |
+
hidden_states,
|
2449 |
+
temb,
|
2450 |
+
image_only_indicator=image_only_indicator,
|
2451 |
+
)
|
2452 |
+
hidden_states = attn(
|
2453 |
+
hidden_states,
|
2454 |
+
encoder_hidden_states=encoder_hidden_states,
|
2455 |
+
image_only_indicator=image_only_indicator,
|
2456 |
+
return_dict=False,
|
2457 |
+
)[0]
|
2458 |
+
|
2459 |
+
if self.upsamplers is not None:
|
2460 |
+
for upsampler in self.upsamplers:
|
2461 |
+
hidden_states = upsampler(hidden_states)
|
2462 |
+
|
2463 |
+
return hidden_states
|
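Taken together, the spatio-temporal blocks in this file follow the usual UNet contract: down blocks return a `(hidden_states, output_states)` pair whose skip tensors are later consumed by the matching up blocks, the mid block takes `encoder_hidden_states` for cross-attention, and every block threads `image_only_indicator` through so temporal mixing can be switched off per frame. The sketch below is a hedged illustration only, with assumed shapes (one 2-frame clip, a 1280-dim timestep embedding, 77 conditioning tokens of width 1024) and the `libs` package assumed importable.

```python
import torch

from libs.unet_3d_blocks import DownBlockSpatioTemporal, UNetMidBlockSpatioTemporal

batch, frames, temb_dim = 1, 2, 1280
sample = torch.randn(batch * frames, 320, 32, 32)
temb = torch.randn(batch * frames, temb_dim)            # one timestep embedding per frame
indicator = torch.zeros(batch, frames)

down = DownBlockSpatioTemporal(in_channels=320, out_channels=320, temb_channels=temb_dim, num_layers=2)
hidden, skips = down(sample, temb, image_only_indicator=indicator)
# `skips` holds one tensor per resnet plus the downsampled output; an UpBlockSpatioTemporal pops them in reverse.

mid = UNetMidBlockSpatioTemporal(in_channels=320, temb_channels=temb_dim,
                                 num_attention_heads=8, cross_attention_dim=1024)
context = torch.randn(batch * frames, 77, 1024)         # text/image conditioning tokens
out = mid(hidden, temb, encoder_hidden_states=context, image_only_indicator=indicator)
print(out.shape, [s.shape for s in skips])
```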
libs/unet_motion_model.py
ADDED
@@ -0,0 +1,975 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import logging, deprecate
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
# from diffusers.models.controlnet import ControlNetConditioningEmbedding
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import (
    CrossAttnDownBlockMotion,
    CrossAttnUpBlockMotion,
    DownBlockMotion,
    UNetMidBlockCrossAttnMotion,
    UpBlockMotion,
    get_down_block,
    get_up_block,
)
from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class MotionModules(nn.Module):
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
in_channels: int,
|
55 |
+
layers_per_block: int = 2,
|
56 |
+
num_attention_heads: int = 8,
|
57 |
+
attention_bias: bool = False,
|
58 |
+
cross_attention_dim: Optional[int] = None,
|
59 |
+
activation_fn: str = "geglu",
|
60 |
+
norm_num_groups: int = 32,
|
61 |
+
max_seq_length: int = 32,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.motion_modules = nn.ModuleList([])
|
65 |
+
|
66 |
+
for i in range(layers_per_block):
|
67 |
+
self.motion_modules.append(
|
68 |
+
TransformerTemporalModel(
|
69 |
+
in_channels=in_channels,
|
70 |
+
norm_num_groups=norm_num_groups,
|
71 |
+
cross_attention_dim=cross_attention_dim,
|
72 |
+
activation_fn=activation_fn,
|
73 |
+
attention_bias=attention_bias,
|
74 |
+
num_attention_heads=num_attention_heads,
|
75 |
+
attention_head_dim=in_channels // num_attention_heads,
|
76 |
+
positional_embeddings="sinusoidal",
|
77 |
+
num_positional_embeddings=max_seq_length,
|
78 |
+
)
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
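`MotionModules` is just a thin container: a `ModuleList` of `TransformerTemporalModel` layers that attend across the frame axis while leaving spatial positions untouched. A small, hedged sketch (illustrative channel and frame counts; `libs` assumed importable as a package):

```python
import torch

from libs.unet_motion_model import MotionModules

mods = MotionModules(in_channels=320, layers_per_block=2, num_attention_heads=8, max_seq_length=32)
print(len(mods.motion_modules))          # 2 temporal transformer layers

x = torch.randn(2, 320, 32, 32)          # (batch * num_frames, C, H, W) with num_frames = 2
y = mods.motion_modules[0](x, num_frames=2).sample
print(y.shape)                           # temporal attention preserves the spatial shape
```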
class MotionAdapter(ModelMixin, ConfigMixin):
|
83 |
+
@register_to_config
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
87 |
+
motion_layers_per_block: int = 2,
|
88 |
+
motion_mid_block_layers_per_block: int = 1,
|
89 |
+
motion_num_attention_heads: int = 8,
|
90 |
+
motion_norm_num_groups: int = 32,
|
91 |
+
motion_max_seq_length: int = 32,
|
92 |
+
use_motion_mid_block: bool = True,
|
93 |
+
):
|
94 |
+
"""Container to store AnimateDiff Motion Modules
|
95 |
+
|
96 |
+
Args:
|
97 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
98 |
+
The tuple of output channels for each UNet block.
|
99 |
+
motion_layers_per_block (`int`, *optional*, defaults to 2):
|
100 |
+
The number of motion layers per UNet block.
|
101 |
+
motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
|
102 |
+
The number of motion layers in the middle UNet block.
|
103 |
+
motion_num_attention_heads (`int`, *optional*, defaults to 8):
|
104 |
+
The number of heads to use in each attention layer of the motion module.
|
105 |
+
motion_norm_num_groups (`int`, *optional*, defaults to 32):
|
106 |
+
The number of groups to use in each group normalization layer of the motion module.
|
107 |
+
motion_max_seq_length (`int`, *optional*, defaults to 32):
|
108 |
+
The maximum sequence length to use in the motion module.
|
109 |
+
use_motion_mid_block (`bool`, *optional*, defaults to True):
|
110 |
+
Whether to use a motion module in the middle of the UNet.
|
111 |
+
"""
|
112 |
+
|
113 |
+
super().__init__()
|
114 |
+
down_blocks = []
|
115 |
+
up_blocks = []
|
116 |
+
|
117 |
+
for i, channel in enumerate(block_out_channels):
|
118 |
+
output_channel = block_out_channels[i]
|
119 |
+
down_blocks.append(
|
120 |
+
MotionModules(
|
121 |
+
in_channels=output_channel,
|
122 |
+
norm_num_groups=motion_norm_num_groups,
|
123 |
+
cross_attention_dim=None,
|
124 |
+
activation_fn="geglu",
|
125 |
+
attention_bias=False,
|
126 |
+
num_attention_heads=motion_num_attention_heads,
|
127 |
+
max_seq_length=motion_max_seq_length,
|
128 |
+
layers_per_block=motion_layers_per_block,
|
129 |
+
)
|
130 |
+
)
|
131 |
+
|
132 |
+
if use_motion_mid_block:
|
133 |
+
self.mid_block = MotionModules(
|
134 |
+
in_channels=block_out_channels[-1],
|
135 |
+
norm_num_groups=motion_norm_num_groups,
|
136 |
+
cross_attention_dim=None,
|
137 |
+
activation_fn="geglu",
|
138 |
+
attention_bias=False,
|
139 |
+
num_attention_heads=motion_num_attention_heads,
|
140 |
+
layers_per_block=motion_mid_block_layers_per_block,
|
141 |
+
max_seq_length=motion_max_seq_length,
|
142 |
+
)
|
143 |
+
else:
|
144 |
+
self.mid_block = None
|
145 |
+
|
146 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
147 |
+
output_channel = reversed_block_out_channels[0]
|
148 |
+
for i, channel in enumerate(reversed_block_out_channels):
|
149 |
+
output_channel = reversed_block_out_channels[i]
|
150 |
+
up_blocks.append(
|
151 |
+
MotionModules(
|
152 |
+
in_channels=output_channel,
|
153 |
+
norm_num_groups=motion_norm_num_groups,
|
154 |
+
cross_attention_dim=None,
|
155 |
+
activation_fn="geglu",
|
156 |
+
attention_bias=False,
|
157 |
+
num_attention_heads=motion_num_attention_heads,
|
158 |
+
max_seq_length=motion_max_seq_length,
|
159 |
+
layers_per_block=motion_layers_per_block + 1,
|
160 |
+
)
|
161 |
+
)
|
162 |
+
|
163 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
164 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
165 |
+
|
166 |
+
def forward(self, sample):
|
167 |
+
pass
|
168 |
+
|
169 |
+
|
170 |
+
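`MotionAdapter` mirrors the docstring above: one `MotionModules` stack per UNet resolution on the down and up paths, plus an optional mid-block stack, so the temporal weights can be serialized and re-attached to any compatible spatial UNet. A hedged construction sketch using the default AnimateDiff-style configuration (all values shown are the defaults, not requirements):

```python
from libs.unet_motion_model import MotionAdapter

adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=2,
    motion_mid_block_layers_per_block=1,
    motion_num_attention_heads=8,
    motion_max_seq_length=32,
    use_motion_mid_block=True,
)
print(len(adapter.down_blocks), len(adapter.up_blocks))   # 4 and 4, one per resolution level
# As a ModelMixin, it round-trips through save_pretrained / from_pretrained like any diffusers model.
```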
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
171 |
+
r"""
|
172 |
+
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
|
173 |
+
sample-shaped output.
|
174 |
+
|
175 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
176 |
+
for all models (such as downloading or saving).
|
177 |
+
"""
|
178 |
+
|
179 |
+
_supports_gradient_checkpointing = True
|
180 |
+
|
181 |
+
@register_to_config
|
182 |
+
def __init__(
|
183 |
+
self,
|
184 |
+
sample_size: Optional[int] = None,
|
185 |
+
in_channels: int = 4,
|
186 |
+
conditioning_channels: int = 3,
|
187 |
+
out_channels: int = 4,
|
188 |
+
down_block_types: Tuple[str, ...] = (
|
189 |
+
"CrossAttnDownBlockMotion",
|
190 |
+
"CrossAttnDownBlockMotion",
|
191 |
+
"CrossAttnDownBlockMotion",
|
192 |
+
"DownBlockMotion",
|
193 |
+
),
|
194 |
+
mid_block_type: Optional[str] = "UNetMidBlockCrossAttnMotion",
|
195 |
+
up_block_types: Tuple[str, ...] = (
|
196 |
+
"UpBlockMotion",
|
197 |
+
"CrossAttnUpBlockMotion",
|
198 |
+
"CrossAttnUpBlockMotion",
|
199 |
+
"CrossAttnUpBlockMotion",
|
200 |
+
),
|
201 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
202 |
+
layers_per_block: int = 2,
|
203 |
+
downsample_padding: int = 1,
|
204 |
+
mid_block_scale_factor: float = 1,
|
205 |
+
act_fn: str = "silu",
|
206 |
+
norm_num_groups: int = 32,
|
207 |
+
norm_eps: float = 1e-5,
|
208 |
+
cross_attention_dim: int = 1280,
|
209 |
+
use_linear_projection: bool = False,
|
210 |
+
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
|
211 |
+
motion_max_seq_length: int = 32,
|
212 |
+
motion_num_attention_heads: int = 8,
|
213 |
+
use_motion_mid_block: bool = True,
|
214 |
+
encoder_hid_dim: Optional[int] = None,
|
215 |
+
encoder_hid_dim_type: Optional[str] = None,
|
216 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
217 |
+
):
|
218 |
+
super().__init__()
|
219 |
+
|
220 |
+
self.sample_size = sample_size
|
221 |
+
|
222 |
+
# Check inputs
|
223 |
+
if len(down_block_types) != len(up_block_types):
|
224 |
+
raise ValueError(
|
225 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
226 |
+
)
|
227 |
+
|
228 |
+
if len(block_out_channels) != len(down_block_types):
|
229 |
+
raise ValueError(
|
230 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
231 |
+
)
|
232 |
+
|
233 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
234 |
+
raise ValueError(
|
235 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
236 |
+
)
|
237 |
+
|
238 |
+
# input
|
239 |
+
conv_in_kernel = 3
|
240 |
+
conv_out_kernel = 3
|
241 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
242 |
+
self.conv_in = nn.Conv2d(
|
243 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
244 |
+
)
|
245 |
+
|
246 |
+
# time
|
247 |
+
time_embed_dim = block_out_channels[0] * 4
|
248 |
+
self.time_proj = Timesteps(block_out_channels[0], True, 0)
|
249 |
+
timestep_input_dim = block_out_channels[0]
|
250 |
+
|
251 |
+
self.time_embedding = TimestepEmbedding(
|
252 |
+
timestep_input_dim,
|
253 |
+
time_embed_dim,
|
254 |
+
act_fn=act_fn,
|
255 |
+
)
|
256 |
+
|
257 |
+
if encoder_hid_dim_type is None:
|
258 |
+
self.encoder_hid_proj = None
|
259 |
+
|
260 |
+
# control net conditioning embedding
|
261 |
+
# self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
262 |
+
# conditioning_embedding_channels=block_out_channels[0],
|
263 |
+
# block_out_channels=conditioning_embedding_out_channels,
|
264 |
+
# conditioning_channels=conditioning_channels,
|
265 |
+
# )
|
266 |
+
|
267 |
+
# class embedding
|
268 |
+
self.down_blocks = nn.ModuleList([])
|
269 |
+
self.up_blocks = nn.ModuleList([])
|
270 |
+
|
271 |
+
if isinstance(num_attention_heads, int):
|
272 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
273 |
+
|
274 |
+
# down
|
275 |
+
output_channel = block_out_channels[0]
|
276 |
+
for i, down_block_type in enumerate(down_block_types):
|
277 |
+
input_channel = output_channel
|
278 |
+
output_channel = block_out_channels[i]
|
279 |
+
is_final_block = i == len(block_out_channels) - 1
|
280 |
+
|
281 |
+
down_block = get_down_block(
|
282 |
+
down_block_type,
|
283 |
+
num_layers=layers_per_block,
|
284 |
+
in_channels=input_channel,
|
285 |
+
out_channels=output_channel,
|
286 |
+
temb_channels=time_embed_dim,
|
287 |
+
add_downsample=not is_final_block,
|
288 |
+
resnet_eps=norm_eps,
|
289 |
+
resnet_act_fn=act_fn,
|
290 |
+
resnet_groups=norm_num_groups,
|
291 |
+
cross_attention_dim=cross_attention_dim,
|
292 |
+
num_attention_heads=num_attention_heads[i],
|
293 |
+
downsample_padding=downsample_padding,
|
294 |
+
use_linear_projection=use_linear_projection,
|
295 |
+
dual_cross_attention=False,
|
296 |
+
temporal_num_attention_heads=motion_num_attention_heads,
|
297 |
+
temporal_max_seq_length=motion_max_seq_length,
|
298 |
+
)
|
299 |
+
self.down_blocks.append(down_block)
|
300 |
+
|
301 |
+
# mid
|
302 |
+
if use_motion_mid_block:
|
303 |
+
self.mid_block = UNetMidBlockCrossAttnMotion(
|
304 |
+
in_channels=block_out_channels[-1],
|
305 |
+
temb_channels=time_embed_dim,
|
306 |
+
resnet_eps=norm_eps,
|
307 |
+
resnet_act_fn=act_fn,
|
308 |
+
output_scale_factor=mid_block_scale_factor,
|
309 |
+
cross_attention_dim=cross_attention_dim,
|
310 |
+
num_attention_heads=num_attention_heads[-1],
|
311 |
+
resnet_groups=norm_num_groups,
|
312 |
+
dual_cross_attention=False,
|
313 |
+
temporal_num_attention_heads=motion_num_attention_heads,
|
314 |
+
temporal_max_seq_length=motion_max_seq_length,
|
315 |
+
)
|
316 |
+
|
317 |
+
else:
|
318 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
319 |
+
in_channels=block_out_channels[-1],
|
320 |
+
temb_channels=time_embed_dim,
|
321 |
+
resnet_eps=norm_eps,
|
322 |
+
resnet_act_fn=act_fn,
|
323 |
+
output_scale_factor=mid_block_scale_factor,
|
324 |
+
cross_attention_dim=cross_attention_dim,
|
325 |
+
num_attention_heads=num_attention_heads[-1],
|
326 |
+
resnet_groups=norm_num_groups,
|
327 |
+
dual_cross_attention=False,
|
328 |
+
)
|
329 |
+
|
330 |
+
# count how many layers upsample the images
|
331 |
+
self.num_upsamplers = 0
|
332 |
+
|
333 |
+
# up
|
334 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
335 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
336 |
+
|
337 |
+
output_channel = reversed_block_out_channels[0]
|
338 |
+
for i, up_block_type in enumerate(up_block_types):
|
339 |
+
is_final_block = i == len(block_out_channels) - 1
|
340 |
+
|
341 |
+
prev_output_channel = output_channel
|
342 |
+
output_channel = reversed_block_out_channels[i]
|
343 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
344 |
+
|
345 |
+
# add upsample block for all BUT final layer
|
346 |
+
if not is_final_block:
|
347 |
+
add_upsample = True
|
348 |
+
self.num_upsamplers += 1
|
349 |
+
else:
|
350 |
+
add_upsample = False
|
351 |
+
|
352 |
+
up_block = get_up_block(
|
353 |
+
up_block_type,
|
354 |
+
num_layers=layers_per_block + 1,
|
355 |
+
in_channels=input_channel,
|
356 |
+
out_channels=output_channel,
|
357 |
+
prev_output_channel=prev_output_channel,
|
358 |
+
temb_channels=time_embed_dim,
|
359 |
+
add_upsample=add_upsample,
|
360 |
+
resnet_eps=norm_eps,
|
361 |
+
resnet_act_fn=act_fn,
|
362 |
+
resnet_groups=norm_num_groups,
|
363 |
+
cross_attention_dim=cross_attention_dim,
|
364 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
365 |
+
dual_cross_attention=False,
|
366 |
+
resolution_idx=i,
|
367 |
+
use_linear_projection=use_linear_projection,
|
368 |
+
temporal_num_attention_heads=motion_num_attention_heads,
|
369 |
+
temporal_max_seq_length=motion_max_seq_length,
|
370 |
+
)
|
371 |
+
self.up_blocks.append(up_block)
|
372 |
+
prev_output_channel = output_channel
|
373 |
+
|
374 |
+
# out
|
375 |
+
if norm_num_groups is not None:
|
376 |
+
self.conv_norm_out = nn.GroupNorm(
|
377 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
378 |
+
)
|
379 |
+
self.conv_act = nn.SiLU()
|
380 |
+
else:
|
381 |
+
self.conv_norm_out = None
|
382 |
+
self.conv_act = None
|
383 |
+
|
384 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
385 |
+
self.conv_out = nn.Conv2d(
|
386 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
387 |
+
)
|
388 |
+
|
389 |
+
@classmethod
|
390 |
+
def from_unet2d(
|
391 |
+
cls,
|
392 |
+
unet: UNet2DConditionModel,
|
393 |
+
motion_adapter: Optional[MotionAdapter] = None,
|
394 |
+
load_weights: bool = True,
|
395 |
+
):
|
396 |
+
has_motion_adapter = motion_adapter is not None
|
397 |
+
|
398 |
+
# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
|
399 |
+
config = unet.config
|
400 |
+
config["_class_name"] = cls.__name__
|
401 |
+
|
402 |
+
down_blocks = []
|
403 |
+
for down_blocks_type in config["down_block_types"]:
|
404 |
+
if "CrossAttn" in down_blocks_type:
|
405 |
+
down_blocks.append("CrossAttnDownBlockMotion")
|
406 |
+
else:
|
407 |
+
down_blocks.append("DownBlockMotion")
|
408 |
+
config["down_block_types"] = down_blocks
|
409 |
+
|
410 |
+
up_blocks = []
|
411 |
+
for down_blocks_type in config["up_block_types"]:
|
412 |
+
if "CrossAttn" in down_blocks_type:
|
413 |
+
up_blocks.append("CrossAttnUpBlockMotion")
|
414 |
+
else:
|
415 |
+
up_blocks.append("UpBlockMotion")
|
416 |
+
|
417 |
+
config["up_block_types"] = up_blocks
|
418 |
+
|
419 |
+
if has_motion_adapter:
|
420 |
+
config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
|
421 |
+
config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
|
422 |
+
config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
|
423 |
+
|
424 |
+
# Need this for backwards compatibility with UNet2DConditionModel checkpoints
|
425 |
+
if not config.get("num_attention_heads"):
|
426 |
+
config["num_attention_heads"] = config["attention_head_dim"]
|
427 |
+
|
428 |
+
model = cls.from_config(config)
|
429 |
+
|
430 |
+
if not load_weights:
|
431 |
+
return model
|
432 |
+
|
433 |
+
model.conv_in.load_state_dict(unet.conv_in.state_dict())
|
434 |
+
model.time_proj.load_state_dict(unet.time_proj.state_dict())
|
435 |
+
model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
436 |
+
# model.controlnet_cond_embedding.load_state_dict(unet.controlnet_cond_embedding.state_dict()) # pose guider
|
437 |
+
|
438 |
+
for i, down_block in enumerate(unet.down_blocks):
|
439 |
+
model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
|
440 |
+
if hasattr(model.down_blocks[i], "attentions"):
|
441 |
+
model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
|
442 |
+
if model.down_blocks[i].downsamplers:
|
443 |
+
model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())
|
444 |
+
|
445 |
+
for i, up_block in enumerate(unet.up_blocks):
|
446 |
+
model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict())
|
447 |
+
if hasattr(model.up_blocks[i], "attentions"):
|
448 |
+
model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict())
|
449 |
+
if model.up_blocks[i].upsamplers:
|
450 |
+
model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())
|
451 |
+
|
452 |
+
model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
|
453 |
+
model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())
|
454 |
+
|
455 |
+
if unet.conv_norm_out is not None:
|
456 |
+
model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
|
457 |
+
if unet.conv_act is not None:
|
458 |
+
model.conv_act.load_state_dict(unet.conv_act.state_dict())
|
459 |
+
model.conv_out.load_state_dict(unet.conv_out.state_dict())
|
460 |
+
|
461 |
+
if has_motion_adapter:
|
462 |
+
model.load_motion_modules(motion_adapter)
|
463 |
+
|
464 |
+
# ensure that the Motion UNet is the same dtype as the UNet2DConditionModel
|
465 |
+
model.to(unet.dtype)
|
466 |
+
|
467 |
+
return model
|
468 |
+
|
469 |
+
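`from_unet2d` inflates a spatial UNet into this motion UNet: it rewrites the block types in the config, copies the 2D weights layer by layer, and then loads the temporal weights from a `MotionAdapter`. A minimal sketch, assuming the companion `libs.unet_2d_condition` module is importable and that a default-configured UNet is acceptable as a stand-in (in practice one would load a pretrained Stable Diffusion UNet instead):

```python
from libs.unet_2d_condition import UNet2DConditionModel
from libs.unet_motion_model import MotionAdapter, UNetMotionModel

unet2d = UNet2DConditionModel()                                  # stand-in for a pretrained SD-style UNet
adapter = MotionAdapter(block_out_channels=tuple(unet2d.config.block_out_channels))
motion_unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter, load_weights=True)

print(type(motion_unet.mid_block).__name__)                      # UNetMidBlockCrossAttnMotion when use_motion_mid_block=True
```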
def freeze_unet2d_params(self) -> None:
|
470 |
+
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
|
471 |
+
unfrozen for fine tuning.
|
472 |
+
"""
|
473 |
+
# Freeze everything
|
474 |
+
for param in self.parameters():
|
475 |
+
param.requires_grad = False
|
476 |
+
|
477 |
+
# Unfreeze Motion Modules
|
478 |
+
for down_block in self.down_blocks:
|
479 |
+
motion_modules = down_block.motion_modules
|
480 |
+
for param in motion_modules.parameters():
|
481 |
+
param.requires_grad = True
|
482 |
+
|
483 |
+
for up_block in self.up_blocks:
|
484 |
+
motion_modules = up_block.motion_modules
|
485 |
+
for param in motion_modules.parameters():
|
486 |
+
param.requires_grad = True
|
487 |
+
|
488 |
+
if hasattr(self.mid_block, "motion_modules"):
|
489 |
+
motion_modules = self.mid_block.motion_modules
|
490 |
+
for param in motion_modules.parameters():
|
491 |
+
param.requires_grad = True
|
492 |
+
|
493 |
+
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
|
494 |
+
for i, down_block in enumerate(motion_adapter.down_blocks):
|
495 |
+
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
|
496 |
+
for i, up_block in enumerate(motion_adapter.up_blocks):
|
497 |
+
self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())
|
498 |
+
|
499 |
+
# to support older motion modules that don't have a mid_block
|
500 |
+
if hasattr(self.mid_block, "motion_modules"):
|
501 |
+
self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())
|
502 |
+
|
503 |
+
def save_motion_modules(
|
504 |
+
self,
|
505 |
+
save_directory: str,
|
506 |
+
is_main_process: bool = True,
|
507 |
+
safe_serialization: bool = True,
|
508 |
+
variant: Optional[str] = None,
|
509 |
+
push_to_hub: bool = False,
|
510 |
+
**kwargs,
|
511 |
+
) -> None:
|
512 |
+
state_dict = self.state_dict()
|
513 |
+
|
514 |
+
# Extract all motion modules
|
515 |
+
motion_state_dict = {}
|
516 |
+
for k, v in state_dict.items():
|
517 |
+
if "motion_modules" in k:
|
518 |
+
motion_state_dict[k] = v
|
519 |
+
|
520 |
+
adapter = MotionAdapter(
|
521 |
+
block_out_channels=self.config["block_out_channels"],
|
522 |
+
motion_layers_per_block=self.config["layers_per_block"],
|
523 |
+
motion_norm_num_groups=self.config["norm_num_groups"],
|
524 |
+
motion_num_attention_heads=self.config["motion_num_attention_heads"],
|
525 |
+
motion_max_seq_length=self.config["motion_max_seq_length"],
|
526 |
+
use_motion_mid_block=self.config["use_motion_mid_block"],
|
527 |
+
)
|
528 |
+
adapter.load_state_dict(motion_state_dict)
|
529 |
+
adapter.save_pretrained(
|
530 |
+
save_directory=save_directory,
|
531 |
+
is_main_process=is_main_process,
|
532 |
+
safe_serialization=safe_serialization,
|
533 |
+
variant=variant,
|
534 |
+
push_to_hub=push_to_hub,
|
535 |
+
**kwargs,
|
536 |
+
)
|
537 |
+
|
538 |
+
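Together, `freeze_unet2d_params`, `load_motion_modules`, and `save_motion_modules` support the intended fine-tuning loop: freeze the spatial weights, train only the temporal layers, and export them as a standalone `MotionAdapter`. A hedged outline under the same assumptions as the previous sketch (default-configured models, `libs` importable); the actual training loop is elided:

```python
import torch

from libs.unet_2d_condition import UNet2DConditionModel
from libs.unet_motion_model import MotionAdapter, UNetMotionModel

motion_unet = UNetMotionModel.from_unet2d(UNet2DConditionModel(), motion_adapter=MotionAdapter())

motion_unet.freeze_unet2d_params()                       # only the motion_modules keep requires_grad=True
trainable = [p for p in motion_unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
# ... optimize `trainable` on video batches here ...

motion_unet.save_motion_modules("./motion_adapter", safe_serialization=True)
# MotionAdapter.from_pretrained("./motion_adapter") plus load_motion_modules() re-attaches the trained weights.
```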
@property
|
539 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
540 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
541 |
+
r"""
|
542 |
+
Returns:
|
543 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
544 |
+
indexed by its weight name.
|
545 |
+
"""
|
546 |
+
# set recursively
|
547 |
+
processors = {}
|
548 |
+
|
549 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
550 |
+
if hasattr(module, "get_processor"):
|
551 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
552 |
+
|
553 |
+
for sub_name, child in module.named_children():
|
554 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
555 |
+
|
556 |
+
return processors
|
557 |
+
|
558 |
+
for name, module in self.named_children():
|
559 |
+
fn_recursive_add_processors(name, module, processors)
|
560 |
+
|
561 |
+
return processors
|
562 |
+
|
563 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
564 |
+
def set_attn_processor(
|
565 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
566 |
+
):
|
567 |
+
r"""
|
568 |
+
Sets the attention processor to use to compute attention.
|
569 |
+
|
570 |
+
Parameters:
|
571 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
572 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
573 |
+
for **all** `Attention` layers.
|
574 |
+
|
575 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
576 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
577 |
+
|
578 |
+
"""
|
579 |
+
count = len(self.attn_processors.keys())
|
580 |
+
|
581 |
+
if isinstance(processor, dict) and len(processor) != count:
|
582 |
+
raise ValueError(
|
583 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
584 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
585 |
+
)
|
586 |
+
|
587 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
588 |
+
if hasattr(module, "set_processor"):
|
589 |
+
if not isinstance(processor, dict):
|
590 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
591 |
+
else:
|
592 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
593 |
+
|
594 |
+
for sub_name, child in module.named_children():
|
595 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
596 |
+
|
597 |
+
for name, module in self.named_children():
|
598 |
+
fn_recursive_attn_processor(name, module, processor)
|
599 |
+
|
600 |
+
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
601 |
+
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
602 |
+
"""
|
603 |
+
Enables [feed forward
|
604 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
605 |
+
|
606 |
+
Parameters:
|
607 |
+
chunk_size (`int`, *optional*):
|
608 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
609 |
+
over each tensor of dim=`dim`.
|
610 |
+
dim (`int`, *optional*, defaults to `0`):
|
611 |
+
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self) -> None:
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self) -> None:
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor, _remove_lora=True)
+
+    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
+            module.gradient_checkpointing = value
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
+        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        for i, upsample_block in enumerate(self.up_blocks):
+            setattr(upsample_block, "s1", s1)
+            setattr(upsample_block, "s2", s2)
+            setattr(upsample_block, "b1", b1)
+            setattr(upsample_block, "b2", b2)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
+    def disable_freeu(self) -> None:
+        """Disables the FreeU mechanism."""
+        freeu_keys = {"s1", "s2", "b1", "b2"}
+        for i, upsample_block in enumerate(self.up_blocks):
+            for k in freeu_keys:
+                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+                    setattr(upsample_block, k, None)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        # controlnet_cond: torch.FloatTensor,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        num_frames: int = 24,
+        down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
+        up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
+        r"""
+        The [`UNetMotionModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch * num_frames, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+                through the `self.time_embedding` layer to obtain the timestep embeddings.
+            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                A tuple of tensors that if specified are added to the residuals of down unet blocks.
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                A tensor that if specified is added to the residual of the middle unet block.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
+        """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # prepare attention_mask
+        if attention_mask is not None:
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0] // num_frames)
+
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=self.dtype)
+
+        emb = self.time_embedding(t_emb, timestep_cond)
+        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+
+        # 2. pre-process
+        # sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+        # N*T C H W
+        sample = self.conv_in(sample)
+        # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+        # sample += controlnet_cond
+
+        # 3. down
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        # but can only use one or the other
+        is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate(
+                "T2I should not use down_block_additional_residuals",
+                "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+                standard_warn=False,
+            )
+            down_intrablock_additional_residuals = down_block_additional_residuals
+            is_adapter = True
+
+        down_block_res_samples = (sample,)
+        if is_brushnet:
+            sample = sample + down_block_add_samples.pop(0)
+
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+                if is_brushnet and len(down_block_add_samples) > 0:
+                    additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
+                        for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
+
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    num_frames=num_frames,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    **additional_residuals,
+                )
+            else:
+                additional_residuals = {}
+                if is_brushnet and len(down_block_add_samples) > 0:
+                    additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
+                        for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames, **additional_residuals,)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)
+
+            down_block_res_samples += res_samples
+
+        if is_controlnet:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples += (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
+        # 4. mid
+        if self.mid_block is not None:
+            # To support older versions of motion modules that don't have a mid_block
+            if hasattr(self.mid_block, "motion_modules"):
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    num_frames=num_frames,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    mid_block_add_sample=mid_block_add_sample,
+                )
+            else:
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    mid_block_add_sample=mid_block_add_sample,
+                )
+
+        if is_controlnet:
+            sample = sample + mid_block_additional_residual
+
+        # if is_brushnet:
+        #     sample = sample + mid_block_add_sample
+
+        if mid_block_additional_residual is not None:
+            sample = sample + mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                additional_residuals = {}
+                if is_brushnet and len(up_block_add_samples) > 0:
+                    additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
+                        for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))]
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                    num_frames=num_frames,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    **additional_residuals,
+                )
+            else:
+                additional_residuals = {}
+                if is_brushnet and len(up_block_add_samples) > 0:
+                    additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
+                        for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))]
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
+                    num_frames=num_frames,
+                    **additional_residuals,
+                )
+
+        # 6. post-process
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
+
+        sample = self.conv_out(sample)
+
+        # reshape to (batch, framerate, channel, width, height)
+        # sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:])
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet3DConditionOutput(sample=sample)
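Note on the chunked feed-forward toggle above: `enable_forward_chunking` simply walks every sub-module depth-first and calls `set_chunk_feed_forward` wherever the hook exists. A minimal, self-contained sketch of the same traversal on a toy module tree (the `ToyBlock` class and `enable_chunking` helper are hypothetical, for illustration only; they are not part of this repository):

import torch

class ToyBlock(torch.nn.Module):
    # hypothetical block exposing the same hook the UNet blocks expose
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self._chunk_size = None

    def set_chunk_feed_forward(self, chunk_size, dim):
        self._chunk_size = (chunk_size, dim)  # record the requested chunking

def enable_chunking(root: torch.nn.Module, chunk_size: int = 1, dim: int = 0):
    # same depth-first walk as enable_forward_chunking above
    def recurse(module: torch.nn.Module):
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
        for child in module.children():
            recurse(child)
    recurse(root)

model = torch.nn.Sequential(ToyBlock(), torch.nn.Sequential(ToyBlock()))
enable_chunking(model, chunk_size=2, dim=0)   # every ToyBlock now records (2, 0)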
propainter/RAFT/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# from .demo import RAFT_infer
+from .raft import RAFT
propainter/RAFT/corr.py
ADDED
@@ -0,0 +1,111 @@
+import torch
+import torch.nn.functional as F
+from .utils.utils import bilinear_sampler, coords_grid
+
+try:
+    import alt_cuda_corr
+except:
+    # alt_cuda_corr is not compiled
+    pass
+
+
+class CorrBlock:
+    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+        self.num_levels = num_levels
+        self.radius = radius
+        self.corr_pyramid = []
+
+        # all pairs correlation
+        corr = CorrBlock.corr(fmap1, fmap2)
+
+        batch, h1, w1, dim, h2, w2 = corr.shape
+        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
+
+        self.corr_pyramid.append(corr)
+        for i in range(self.num_levels-1):
+            corr = F.avg_pool2d(corr, 2, stride=2)
+            self.corr_pyramid.append(corr)
+
+    def __call__(self, coords):
+        r = self.radius
+        coords = coords.permute(0, 2, 3, 1)
+        batch, h1, w1, _ = coords.shape
+
+        out_pyramid = []
+        for i in range(self.num_levels):
+            corr = self.corr_pyramid[i]
+            dx = torch.linspace(-r, r, 2*r+1)
+            dy = torch.linspace(-r, r, 2*r+1)
+            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
+
+            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
+            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+            coords_lvl = centroid_lvl + delta_lvl
+
+            corr = bilinear_sampler(corr, coords_lvl)
+            corr = corr.view(batch, h1, w1, -1)
+            out_pyramid.append(corr)
+
+        out = torch.cat(out_pyramid, dim=-1)
+        return out.permute(0, 3, 1, 2).contiguous().float()
+
+    @staticmethod
+    def corr(fmap1, fmap2):
+        batch, dim, ht, wd = fmap1.shape
+        fmap1 = fmap1.view(batch, dim, ht*wd)
+        fmap2 = fmap2.view(batch, dim, ht*wd)
+
+        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
+        corr = corr.view(batch, ht, wd, 1, ht, wd)
+        return corr / torch.sqrt(torch.tensor(dim).float())
+
+
+class CorrLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, fmap1, fmap2, coords, r):
+        fmap1 = fmap1.contiguous()
+        fmap2 = fmap2.contiguous()
+        coords = coords.contiguous()
+        ctx.save_for_backward(fmap1, fmap2, coords)
+        ctx.r = r
+        corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
+        return corr
+
+    @staticmethod
+    def backward(ctx, grad_corr):
+        fmap1, fmap2, coords = ctx.saved_tensors
+        grad_corr = grad_corr.contiguous()
+        fmap1_grad, fmap2_grad, coords_grad = \
+            correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
+        return fmap1_grad, fmap2_grad, coords_grad, None
+
+
+class AlternateCorrBlock:
+    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+        self.num_levels = num_levels
+        self.radius = radius
+
+        self.pyramid = [(fmap1, fmap2)]
+        for i in range(self.num_levels):
+            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
+            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
+            self.pyramid.append((fmap1, fmap2))
+
+    def __call__(self, coords):
+
+        coords = coords.permute(0, 2, 3, 1)
+        B, H, W, _ = coords.shape
+
+        corr_list = []
+        for i in range(self.num_levels):
+            r = self.radius
+            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
+            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
+
+            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
+            corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
+            corr_list.append(corr.squeeze(1))
+
+        corr = torch.stack(corr_list, dim=1)
+        corr = corr.reshape(B, -1, H, W)
+        return corr / 16.0
|
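The all-pairs correlation built in `CorrBlock.corr` above is a scaled dot product between every pair of feature vectors from the two frames. A minimal standalone sketch of that computation with random feature maps (the shapes are illustrative only, not taken from this repository):

import torch

B, D, H, W = 1, 256, 46, 62          # illustrative feature-map shape
fmap1 = torch.randn(B, D, H, W)
fmap2 = torch.randn(B, D, H, W)

# corr[b, i, j, 0, k, l] = <fmap1[b, :, i, j], fmap2[b, :, k, l]> / sqrt(D)
corr = torch.matmul(fmap1.view(B, D, H * W).transpose(1, 2),
                    fmap2.view(B, D, H * W))
corr = corr.view(B, H, W, 1, H, W) / D ** 0.5
print(corr.shape)  # torch.Size([1, 46, 62, 1, 46, 62])

The pyramid then average-pools the last two dimensions so that each lookup radius covers a progressively larger displacement at coarser levels.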
propainter/RAFT/datasets.py
ADDED
@@ -0,0 +1,235 @@
+# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torch.nn.functional as F
+
+import os
+import math
+import random
+from glob import glob
+import os.path as osp
+
+from utils import frame_utils
+from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
+
+
+class FlowDataset(data.Dataset):
+    def __init__(self, aug_params=None, sparse=False):
+        self.augmentor = None
+        self.sparse = sparse
+        if aug_params is not None:
+            if sparse:
+                self.augmentor = SparseFlowAugmentor(**aug_params)
+            else:
+                self.augmentor = FlowAugmentor(**aug_params)
+
+        self.is_test = False
+        self.init_seed = False
+        self.flow_list = []
+        self.image_list = []
+        self.extra_info = []
+
+    def __getitem__(self, index):
+
+        if self.is_test:
+            img1 = frame_utils.read_gen(self.image_list[index][0])
+            img2 = frame_utils.read_gen(self.image_list[index][1])
+            img1 = np.array(img1).astype(np.uint8)[..., :3]
+            img2 = np.array(img2).astype(np.uint8)[..., :3]
+            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+            return img1, img2, self.extra_info[index]
+
+        if not self.init_seed:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                torch.manual_seed(worker_info.id)
+                np.random.seed(worker_info.id)
+                random.seed(worker_info.id)
+                self.init_seed = True
+
+        index = index % len(self.image_list)
+        valid = None
+        if self.sparse:
+            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
+        else:
+            flow = frame_utils.read_gen(self.flow_list[index])
+
+        img1 = frame_utils.read_gen(self.image_list[index][0])
+        img2 = frame_utils.read_gen(self.image_list[index][1])
+
+        flow = np.array(flow).astype(np.float32)
+        img1 = np.array(img1).astype(np.uint8)
+        img2 = np.array(img2).astype(np.uint8)
+
+        # grayscale images
+        if len(img1.shape) == 2:
+            img1 = np.tile(img1[..., None], (1, 1, 3))
+            img2 = np.tile(img2[..., None], (1, 1, 3))
+        else:
+            img1 = img1[..., :3]
+            img2 = img2[..., :3]
+
+        if self.augmentor is not None:
+            if self.sparse:
+                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+            else:
+                img1, img2, flow = self.augmentor(img1, img2, flow)
+
+        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+        flow = torch.from_numpy(flow).permute(2, 0, 1).float()
+
+        if valid is not None:
+            valid = torch.from_numpy(valid)
+        else:
+            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
+
+        return img1, img2, flow, valid.float()
+
+    def __rmul__(self, v):
+        self.flow_list = v * self.flow_list
+        self.image_list = v * self.image_list
+        return self
+
+    def __len__(self):
+        return len(self.image_list)
+
+
+class MpiSintel(FlowDataset):
+    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
+        super(MpiSintel, self).__init__(aug_params)
+        flow_root = osp.join(root, split, 'flow')
+        image_root = osp.join(root, split, dstype)
+
+        if split == 'test':
+            self.is_test = True
+
+        for scene in os.listdir(image_root):
+            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
+            for i in range(len(image_list)-1):
+                self.image_list += [ [image_list[i], image_list[i+1]] ]
+                self.extra_info += [ (scene, i) ]  # scene and frame_id
+
+            if split != 'test':
+                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
+
+
+class FlyingChairs(FlowDataset):
+    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
+        super(FlyingChairs, self).__init__(aug_params)
+
+        images = sorted(glob(osp.join(root, '*.ppm')))
+        flows = sorted(glob(osp.join(root, '*.flo')))
+        assert (len(images)//2 == len(flows))
+
+        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
+        for i in range(len(flows)):
+            xid = split_list[i]
+            if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2):
+                self.flow_list += [ flows[i] ]
+                self.image_list += [ [images[2*i], images[2*i+1]] ]
+
+
+class FlyingThings3D(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
+        super(FlyingThings3D, self).__init__(aug_params)
+
+        for cam in ['left']:
+            for direction in ['into_future', 'into_past']:
+                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
+                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
+
+                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
+                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
+
+                for idir, fdir in zip(image_dirs, flow_dirs):
+                    images = sorted(glob(osp.join(idir, '*.png')))
+                    flows = sorted(glob(osp.join(fdir, '*.pfm')))
+                    for i in range(len(flows)-1):
+                        if direction == 'into_future':
+                            self.image_list += [ [images[i], images[i+1]] ]
+                            self.flow_list += [ flows[i] ]
+                        elif direction == 'into_past':
+                            self.image_list += [ [images[i+1], images[i]] ]
+                            self.flow_list += [ flows[i+1] ]
+
+
+class KITTI(FlowDataset):
+    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
+        super(KITTI, self).__init__(aug_params, sparse=True)
+        if split == 'testing':
+            self.is_test = True
+
+        root = osp.join(root, split)
+        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
+        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+
+        for img1, img2 in zip(images1, images2):
+            frame_id = img1.split('/')[-1]
+            self.extra_info += [ [frame_id] ]
+            self.image_list += [ [img1, img2] ]
+
+        if split == 'training':
+            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+
+
+class HD1K(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/HD1k'):
+        super(HD1K, self).__init__(aug_params, sparse=True)
+
+        seq_ix = 0
+        while 1:
+            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
+            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+
+            if len(flows) == 0:
+                break
+
+            for i in range(len(flows)-1):
+                self.flow_list += [flows[i]]
+                self.image_list += [ [images[i], images[i+1]] ]
+
+            seq_ix += 1
+
+
+def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
+    """ Create the data loader for the corresponding training set """
+
+    if args.stage == 'chairs':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
+        train_dataset = FlyingChairs(aug_params, split='training')
+
+    elif args.stage == 'things':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+        train_dataset = clean_dataset + final_dataset
+
+    elif args.stage == 'sintel':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
+        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+        sintel_final = MpiSintel(aug_params, split='training', dstype='final')
+
+        if TRAIN_DS == 'C+T+K+S+H':
+            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
+            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
+            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
+
+        elif TRAIN_DS == 'C+T+K/S':
+            train_dataset = 100*sintel_clean + 100*sintel_final + things
+
+    elif args.stage == 'kitti':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+        train_dataset = KITTI(aug_params, split='training')
+
+    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
+        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
+
+    print('Training with %d image pairs' % len(train_dataset))
+    return train_loader
|
235 |
+
|
propainter/RAFT/demo.py
ADDED
@@ -0,0 +1,79 @@
+import sys
+import argparse
+import os
+import cv2
+import glob
+import numpy as np
+import torch
+from PIL import Image
+
+from .raft import RAFT
+from .utils import flow_viz
+from .utils.utils import InputPadder
+
+
+DEVICE = 'cuda'
+
+def load_image(imfile):
+    img = np.array(Image.open(imfile)).astype(np.uint8)
+    img = torch.from_numpy(img).permute(2, 0, 1).float()
+    return img
+
+
+def load_image_list(image_files):
+    images = []
+    for imfile in sorted(image_files):
+        images.append(load_image(imfile))
+
+    images = torch.stack(images, dim=0)
+    images = images.to(DEVICE)
+
+    padder = InputPadder(images.shape)
+    return padder.pad(images)[0]
+
+
+def viz(img, flo):
+    img = img[0].permute(1, 2, 0).cpu().numpy()
+    flo = flo[0].permute(1, 2, 0).cpu().numpy()
+
+    # map flow to rgb image
+    flo = flow_viz.flow_to_image(flo)
+    # img_flo = np.concatenate([img, flo], axis=0)
+    img_flo = flo
+
+    cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2, 1, 0]])
+    # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
+    # cv2.waitKey()
+
+
+def demo(args):
+    model = torch.nn.DataParallel(RAFT(args))
+    model.load_state_dict(torch.load(args.model))
+
+    model = model.module
+    model.to(DEVICE)
+    model.eval()
+
+    with torch.no_grad():
+        images = glob.glob(os.path.join(args.path, '*.png')) + \
+                 glob.glob(os.path.join(args.path, '*.jpg'))
+
+        images = load_image_list(images)
+        for i in range(images.shape[0]-1):
+            image1 = images[i, None]
+            image2 = images[i+1, None]
+
+            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+            viz(image1, flow_up)
+
+
+def RAFT_infer(args):
+    model = torch.nn.DataParallel(RAFT(args))
+    model.load_state_dict(torch.load(args.model))
+
+    model = model.module
+    model.to(DEVICE)
+    model.eval()
+
+    return model
propainter/RAFT/extractor.py
ADDED
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+        super(ResidualBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+        num_groups = planes // 8
+
+        if norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            if not stride == 1:
+                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+        elif norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(planes)
+            self.norm2 = nn.BatchNorm2d(planes)
+            if not stride == 1:
+                self.norm3 = nn.BatchNorm2d(planes)
+
+        elif norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(planes)
+            self.norm2 = nn.InstanceNorm2d(planes)
+            if not stride == 1:
+                self.norm3 = nn.InstanceNorm2d(planes)
+
+        elif norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+            self.norm2 = nn.Sequential()
+            if not stride == 1:
+                self.norm3 = nn.Sequential()
+
+        if stride == 1:
+            self.downsample = None
+
+        else:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+    def forward(self, x):
+        y = x
+        y = self.relu(self.norm1(self.conv1(y)))
+        y = self.relu(self.norm2(self.conv2(y)))
+
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        return self.relu(x+y)
+
+
+class BottleneckBlock(nn.Module):
+    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+        super(BottleneckBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
+        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
+        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
+        self.relu = nn.ReLU(inplace=True)
+
+        num_groups = planes // 8
+
+        if norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            if not stride == 1:
+                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+        elif norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(planes//4)
+            self.norm2 = nn.BatchNorm2d(planes//4)
+            self.norm3 = nn.BatchNorm2d(planes)
+            if not stride == 1:
+                self.norm4 = nn.BatchNorm2d(planes)
+
+        elif norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(planes//4)
+            self.norm2 = nn.InstanceNorm2d(planes//4)
+            self.norm3 = nn.InstanceNorm2d(planes)
+            if not stride == 1:
+                self.norm4 = nn.InstanceNorm2d(planes)
+
+        elif norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+            self.norm2 = nn.Sequential()
+            self.norm3 = nn.Sequential()
+            if not stride == 1:
+                self.norm4 = nn.Sequential()
+
+        if stride == 1:
+            self.downsample = None
+
+        else:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
+
+    def forward(self, x):
+        y = x
+        y = self.relu(self.norm1(self.conv1(y)))
+        y = self.relu(self.norm2(self.conv2(y)))
+        y = self.relu(self.norm3(self.conv3(y)))
+
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        return self.relu(x+y)
+
+class BasicEncoder(nn.Module):
+    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+        super(BasicEncoder, self).__init__()
+        self.norm_fn = norm_fn
+
+        if self.norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+        elif self.norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(64)
+
+        elif self.norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(64)
+
+        elif self.norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.in_planes = 64
+        self.layer1 = self._make_layer(64, stride=1)
+        self.layer2 = self._make_layer(96, stride=2)
+        self.layer3 = self._make_layer(128, stride=2)
+
+        # output convolution
+        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
+
+        self.dropout = None
+        if dropout > 0:
+            self.dropout = nn.Dropout2d(p=dropout)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, dim, stride=1):
+        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+        layers = (layer1, layer2)
+
+        self.in_planes = dim
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+
+        # if input is list, combine batch dimension
+        is_list = isinstance(x, tuple) or isinstance(x, list)
+        if is_list:
+            batch_dim = x[0].shape[0]
+            x = torch.cat(x, dim=0)
+
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.conv2(x)
+
+        if self.training and self.dropout is not None:
+            x = self.dropout(x)
+
+        if is_list:
+            x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+        return x
+
+
+class SmallEncoder(nn.Module):
+    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+        super(SmallEncoder, self).__init__()
+        self.norm_fn = norm_fn
+
+        if self.norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
+
+        elif self.norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(32)
+
+        elif self.norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(32)
+
+        elif self.norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+
+        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.in_planes = 32
+        self.layer1 = self._make_layer(32, stride=1)
+        self.layer2 = self._make_layer(64, stride=2)
+        self.layer3 = self._make_layer(96, stride=2)
+
+        self.dropout = None
+        if dropout > 0:
+            self.dropout = nn.Dropout2d(p=dropout)
+
+        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, dim, stride=1):
+        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
+        layers = (layer1, layer2)
+
+        self.in_planes = dim
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+
+        # if input is list, combine batch dimension
+        is_list = isinstance(x, tuple) or isinstance(x, list)
+        if is_list:
+            batch_dim = x[0].shape[0]
+            x = torch.cat(x, dim=0)
+
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.conv2(x)
+
+        if self.training and self.dropout is not None:
+            x = self.dropout(x)
+
+        if is_list:
+            x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+        return x
propainter/RAFT/raft.py
ADDED
@@ -0,0 +1,146 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .update import BasicUpdateBlock, SmallUpdateBlock
+from .extractor import BasicEncoder, SmallEncoder
+from .corr import CorrBlock, AlternateCorrBlock
+from .utils.utils import bilinear_sampler, coords_grid, upflow8
+
+try:
+    autocast = torch.cuda.amp.autocast
+except:
+    # dummy autocast for PyTorch < 1.6
+    class autocast:
+        def __init__(self, enabled):
+            pass
+        def __enter__(self):
+            pass
+        def __exit__(self, *args):
+            pass
+
+
+class RAFT(nn.Module):
+    def __init__(self, args):
+        super(RAFT, self).__init__()
+        self.args = args
+
+        if args.small:
+            self.hidden_dim = hdim = 96
+            self.context_dim = cdim = 64
+            args.corr_levels = 4
+            args.corr_radius = 3
+
+        else:
+            self.hidden_dim = hdim = 128
+            self.context_dim = cdim = 128
+            args.corr_levels = 4
+            args.corr_radius = 4
+
+        if 'dropout' not in args._get_kwargs():
+            args.dropout = 0
+
+        if 'alternate_corr' not in args._get_kwargs():
+            args.alternate_corr = False
+
+        # feature network, context network, and update block
+        if args.small:
+            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
+            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
+            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
+
+        else:
+            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
+            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
+            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
+
+    def freeze_bn(self):
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
+    def initialize_flow(self, img):
+        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+        N, C, H, W = img.shape
+        coords0 = coords_grid(N, H//8, W//8).to(img.device)
+        coords1 = coords_grid(N, H//8, W//8).to(img.device)
+
+        # optical flow computed as difference: flow = coords1 - coords0
+        return coords0, coords1
+
+    def upsample_flow(self, flow, mask):
+        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+        N, _, H, W = flow.shape
+        mask = mask.view(N, 1, 9, 8, 8, H, W)
+        mask = torch.softmax(mask, dim=2)
+
+        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
+        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+        up_flow = torch.sum(mask * up_flow, dim=2)
+        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+        return up_flow.reshape(N, 2, 8*H, 8*W)
+
+    def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True):
+        """ Estimate optical flow between pair of frames """
+
+        # image1 = 2 * (image1 / 255.0) - 1.0
+        # image2 = 2 * (image2 / 255.0) - 1.0
+
+        image1 = image1.contiguous()
+        image2 = image2.contiguous()
+
+        hdim = self.hidden_dim
+        cdim = self.context_dim
+
+        # run the feature network
+        with autocast(enabled=self.args.mixed_precision):
+            fmap1, fmap2 = self.fnet([image1, image2])
+
+        fmap1 = fmap1.float()
+        fmap2 = fmap2.float()
+
+        if self.args.alternate_corr:
+            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
+        else:
+            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
+
+        # run the context network
+        with autocast(enabled=self.args.mixed_precision):
+            cnet = self.cnet(image1)
+            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
+            net = torch.tanh(net)
+            inp = torch.relu(inp)
+
+        coords0, coords1 = self.initialize_flow(image1)
+
+        if flow_init is not None:
+            coords1 = coords1 + flow_init
+
+        flow_predictions = []
+        for itr in range(iters):
+            coords1 = coords1.detach()
+            corr = corr_fn(coords1)  # index correlation volume
+
+            flow = coords1 - coords0
+            with autocast(enabled=self.args.mixed_precision):
+                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
+
+            # F(t+1) = F(t) + \Delta(t)
+            coords1 = coords1 + delta_flow
+
+            # upsample predictions
+            if up_mask is None:
+                flow_up = upflow8(coords1 - coords0)
+            else:
+                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+
+            flow_predictions.append(flow_up)
+
+        if test_mode:
+            return coords1 - coords0, flow_up
+
+        return flow_predictions
|
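`upsample_flow` above turns each coarse flow vector into an 8x8 patch of full-resolution flow by taking a learned convex combination over its 3x3 neighbourhood. A standalone sketch of the same arithmetic with a random mask, so the tensor bookkeeping is visible (shapes are illustrative only):

import torch
import torch.nn.functional as F

N, H, W = 1, 8, 8
flow = torch.randn(N, 2, H, W)             # coarse flow at 1/8 resolution
mask = torch.randn(N, 64 * 9, H, W)        # predicted combination weights

mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)          # convex weights over the 3x3 neighbourhood

up_flow = F.unfold(8 * flow, [3, 3], padding=1)    # gather 3x3 neighbourhoods, flow scaled by 8
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)         # convex combination per output pixel
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
print(up_flow.reshape(N, 2, 8 * H, 8 * W).shape)   # torch.Size([1, 2, 64, 64])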
propainter/RAFT/update.py
ADDED
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FlowHead(nn.Module):
+    def __init__(self, input_dim=128, hidden_dim=256):
+        super(FlowHead, self).__init__()
+        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.conv2(self.relu(self.conv1(x)))
+
+class ConvGRU(nn.Module):
+    def __init__(self, hidden_dim=128, input_dim=192+128):
+        super(ConvGRU, self).__init__()
+        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+
+    def forward(self, h, x):
+        hx = torch.cat([h, x], dim=1)
+
+        z = torch.sigmoid(self.convz(hx))
+        r = torch.sigmoid(self.convr(hx))
+        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
+
+        h = (1-z) * h + z * q
+        return h
+
+class SepConvGRU(nn.Module):
+    def __init__(self, hidden_dim=128, input_dim=192+128):
+        super(SepConvGRU, self).__init__()
+        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
+        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
+        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
+
+        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))
+        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))
+        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))
+
+    def forward(self, h, x):
+        # horizontal
+        hx = torch.cat([h, x], dim=1)
+        z = torch.sigmoid(self.convz1(hx))
+        r = torch.sigmoid(self.convr1(hx))
+        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
+        h = (1-z) * h + z * q
+
+        # vertical
+        hx = torch.cat([h, x], dim=1)
+        z = torch.sigmoid(self.convz2(hx))
+        r = torch.sigmoid(self.convr2(hx))
+        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
+        h = (1-z) * h + z * q
+
+        return h
+
+class SmallMotionEncoder(nn.Module):
+    def __init__(self, args):
+        super(SmallMotionEncoder, self).__init__()
+        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
+        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
+        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
+        self.conv = nn.Conv2d(128, 80, 3, padding=1)
+
+    def forward(self, flow, corr):
+        cor = F.relu(self.convc1(corr))
+        flo = F.relu(self.convf1(flow))
+        flo = F.relu(self.convf2(flo))
+        cor_flo = torch.cat([cor, flo], dim=1)
+        out = F.relu(self.conv(cor_flo))
+        return torch.cat([out, flow], dim=1)
+
+class BasicMotionEncoder(nn.Module):
+    def __init__(self, args):
+        super(BasicMotionEncoder, self).__init__()
+        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
+        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
+        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
+        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
+
+    def forward(self, flow, corr):
+        cor = F.relu(self.convc1(corr))
+        cor = F.relu(self.convc2(cor))
+        flo = F.relu(self.convf1(flow))
+        flo = F.relu(self.convf2(flo))
+
+        cor_flo = torch.cat([cor, flo], dim=1)
+        out = F.relu(self.conv(cor_flo))
+        return torch.cat([out, flow], dim=1)
+
+class SmallUpdateBlock(nn.Module):
+    def __init__(self, args, hidden_dim=96):
+        super(SmallUpdateBlock, self).__init__()
+        self.encoder = SmallMotionEncoder(args)
+        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
+        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
+
+    def forward(self, net, inp, corr, flow):
+        motion_features = self.encoder(flow, corr)
+        inp = torch.cat([inp, motion_features], dim=1)
+        net = self.gru(net, inp)
+        delta_flow = self.flow_head(net)
+
+        return net, None, delta_flow
+
+class BasicUpdateBlock(nn.Module):
+    def __init__(self, args, hidden_dim=128, input_dim=128):
+        super(BasicUpdateBlock, self).__init__()
+        self.args = args
+        self.encoder = BasicMotionEncoder(args)
+        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
+        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
+
+        self.mask = nn.Sequential(
+            nn.Conv2d(128, 256, 3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 64*9, 1, padding=0))
+
+    def forward(self, net, inp, corr, flow, upsample=True):
+        motion_features = self.encoder(flow, corr)
+        inp = torch.cat([inp, motion_features], dim=1)
+
+        net = self.gru(net, inp)
+        delta_flow = self.flow_head(net)
+
+        # scale mask to balance gradients
+        mask = .25 * self.mask(net)
+        return net, mask, delta_flow
|
137 |
+
|
138 |
+
|
139 |
+
|
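Both update blocks drive a convolutional GRU: the hidden state is mixed with the motion features through sigmoid update/reset gates. A minimal self-contained sketch of one ConvGRU step, mirroring the gating in update.py (the dimensions and random tensors are illustrative only):

import torch
import torch.nn as nn

hidden_dim, input_dim = 128, 192 + 128
convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

h = torch.zeros(1, hidden_dim, 32, 32)    # hidden state
x = torch.randn(1, input_dim, 32, 32)     # motion features + context

hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(convz(hx))              # update gate
r = torch.sigmoid(convr(hx))              # reset gate
q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q                   # gated state update
print(h.shape)  # torch.Size([1, 128, 32, 32])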
propainter/RAFT/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .flow_viz import flow_to_image
+from .frame_utils import writeFlow
propainter/RAFT/utils/augmentor.py
ADDED
@@ -0,0 +1,246 @@
+import numpy as np
+import random
+import math
+from PIL import Image
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+import torch
+from torchvision.transforms import ColorJitter
+import torch.nn.functional as F
+
+
+class FlowAugmentor:
+    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
+
+        # spatial augmentation params
+        self.crop_size = crop_size
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.spatial_aug_prob = 0.8
+        self.stretch_prob = 0.8
+        self.max_stretch = 0.2
+
+        # flip augmentation params
+        self.do_flip = do_flip
+        self.h_flip_prob = 0.5
+        self.v_flip_prob = 0.1
+
+        # photometric augmentation params
+        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
+        self.asymmetric_color_aug_prob = 0.2
+        self.eraser_aug_prob = 0.5
+
+    def color_transform(self, img1, img2):
+        """ Photometric augmentation """
+
+        # asymmetric
+        if np.random.rand() < self.asymmetric_color_aug_prob:
+            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
+
+        # symmetric
+        else:
+            image_stack = np.concatenate([img1, img2], axis=0)
+            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+            img1, img2 = np.split(image_stack, 2, axis=0)
+
+        return img1, img2
+
+    def eraser_transform(self, img1, img2, bounds=[50, 100]):
+        """ Occlusion augmentation """
+
+        ht, wd = img1.shape[:2]
+        if np.random.rand() < self.eraser_aug_prob:
+            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+            for _ in range(np.random.randint(1, 3)):
+                x0 = np.random.randint(0, wd)
+                y0 = np.random.randint(0, ht)
+                dx = np.random.randint(bounds[0], bounds[1])
+                dy = np.random.randint(bounds[0], bounds[1])
+                img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+        return img1, img2
+
+    def spatial_transform(self, img1, img2, flow):
+        # randomly sample scale
+        ht, wd = img1.shape[:2]
+        min_scale = np.maximum(
+            (self.crop_size[0] + 8) / float(ht),
+            (self.crop_size[1] + 8) / float(wd))
+
+        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+        scale_x = scale
+        scale_y = scale
+        if np.random.rand() < self.stretch_prob:
+            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+
+        scale_x = np.clip(scale_x, min_scale, None)
+        scale_y = np.clip(scale_y, min_scale, None)
+
+        if np.random.rand() < self.spatial_aug_prob:
+            # rescale the images
+            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
# rescale the images
|
86 |
+
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
87 |
+
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
88 |
+
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
89 |
+
flow = flow * [scale_x, scale_y]
|
90 |
+
|
91 |
+
if self.do_flip:
|
92 |
+
if np.random.rand() < self.h_flip_prob: # h-flip
|
93 |
+
img1 = img1[:, ::-1]
|
94 |
+
img2 = img2[:, ::-1]
|
95 |
+
flow = flow[:, ::-1] * [-1.0, 1.0]
|
96 |
+
|
97 |
+
if np.random.rand() < self.v_flip_prob: # v-flip
|
98 |
+
img1 = img1[::-1, :]
|
99 |
+
img2 = img2[::-1, :]
|
100 |
+
flow = flow[::-1, :] * [1.0, -1.0]
|
101 |
+
|
102 |
+
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
|
103 |
+
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
|
104 |
+
|
105 |
+
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
106 |
+
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
107 |
+
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
108 |
+
|
109 |
+
return img1, img2, flow
|
110 |
+
|
111 |
+
def __call__(self, img1, img2, flow):
|
112 |
+
img1, img2 = self.color_transform(img1, img2)
|
113 |
+
img1, img2 = self.eraser_transform(img1, img2)
|
114 |
+
img1, img2, flow = self.spatial_transform(img1, img2, flow)
|
115 |
+
|
116 |
+
img1 = np.ascontiguousarray(img1)
|
117 |
+
img2 = np.ascontiguousarray(img2)
|
118 |
+
flow = np.ascontiguousarray(flow)
|
119 |
+
|
120 |
+
return img1, img2, flow
|
121 |
+
|
122 |
+
class SparseFlowAugmentor:
|
123 |
+
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
|
124 |
+
# spatial augmentation params
|
125 |
+
self.crop_size = crop_size
|
126 |
+
self.min_scale = min_scale
|
127 |
+
self.max_scale = max_scale
|
128 |
+
self.spatial_aug_prob = 0.8
|
129 |
+
self.stretch_prob = 0.8
|
130 |
+
self.max_stretch = 0.2
|
131 |
+
|
132 |
+
# flip augmentation params
|
133 |
+
self.do_flip = do_flip
|
134 |
+
self.h_flip_prob = 0.5
|
135 |
+
self.v_flip_prob = 0.1
|
136 |
+
|
137 |
+
# photometric augmentation params
|
138 |
+
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
|
139 |
+
self.asymmetric_color_aug_prob = 0.2
|
140 |
+
self.eraser_aug_prob = 0.5
|
141 |
+
|
142 |
+
def color_transform(self, img1, img2):
|
143 |
+
image_stack = np.concatenate([img1, img2], axis=0)
|
144 |
+
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
145 |
+
img1, img2 = np.split(image_stack, 2, axis=0)
|
146 |
+
return img1, img2
|
147 |
+
|
148 |
+
def eraser_transform(self, img1, img2):
|
149 |
+
ht, wd = img1.shape[:2]
|
150 |
+
if np.random.rand() < self.eraser_aug_prob:
|
151 |
+
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
152 |
+
for _ in range(np.random.randint(1, 3)):
|
153 |
+
x0 = np.random.randint(0, wd)
|
154 |
+
y0 = np.random.randint(0, ht)
|
155 |
+
dx = np.random.randint(50, 100)
|
156 |
+
dy = np.random.randint(50, 100)
|
157 |
+
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
|
158 |
+
|
159 |
+
return img1, img2
|
160 |
+
|
161 |
+
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
|
162 |
+
ht, wd = flow.shape[:2]
|
163 |
+
coords = np.meshgrid(np.arange(wd), np.arange(ht))
|
164 |
+
coords = np.stack(coords, axis=-1)
|
165 |
+
|
166 |
+
coords = coords.reshape(-1, 2).astype(np.float32)
|
167 |
+
flow = flow.reshape(-1, 2).astype(np.float32)
|
168 |
+
valid = valid.reshape(-1).astype(np.float32)
|
169 |
+
|
170 |
+
coords0 = coords[valid>=1]
|
171 |
+
flow0 = flow[valid>=1]
|
172 |
+
|
173 |
+
ht1 = int(round(ht * fy))
|
174 |
+
wd1 = int(round(wd * fx))
|
175 |
+
|
176 |
+
coords1 = coords0 * [fx, fy]
|
177 |
+
flow1 = flow0 * [fx, fy]
|
178 |
+
|
179 |
+
xx = np.round(coords1[:,0]).astype(np.int32)
|
180 |
+
yy = np.round(coords1[:,1]).astype(np.int32)
|
181 |
+
|
182 |
+
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
|
183 |
+
xx = xx[v]
|
184 |
+
yy = yy[v]
|
185 |
+
flow1 = flow1[v]
|
186 |
+
|
187 |
+
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
|
188 |
+
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
|
189 |
+
|
190 |
+
flow_img[yy, xx] = flow1
|
191 |
+
valid_img[yy, xx] = 1
|
192 |
+
|
193 |
+
return flow_img, valid_img
|
194 |
+
|
195 |
+
def spatial_transform(self, img1, img2, flow, valid):
|
196 |
+
# randomly sample scale
|
197 |
+
|
198 |
+
ht, wd = img1.shape[:2]
|
199 |
+
min_scale = np.maximum(
|
200 |
+
(self.crop_size[0] + 1) / float(ht),
|
201 |
+
(self.crop_size[1] + 1) / float(wd))
|
202 |
+
|
203 |
+
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
204 |
+
scale_x = np.clip(scale, min_scale, None)
|
205 |
+
scale_y = np.clip(scale, min_scale, None)
|
206 |
+
|
207 |
+
if np.random.rand() < self.spatial_aug_prob:
|
208 |
+
# rescale the images
|
209 |
+
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
210 |
+
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
211 |
+
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
|
212 |
+
|
213 |
+
if self.do_flip:
|
214 |
+
if np.random.rand() < 0.5: # h-flip
|
215 |
+
img1 = img1[:, ::-1]
|
216 |
+
img2 = img2[:, ::-1]
|
217 |
+
flow = flow[:, ::-1] * [-1.0, 1.0]
|
218 |
+
valid = valid[:, ::-1]
|
219 |
+
|
220 |
+
margin_y = 20
|
221 |
+
margin_x = 50
|
222 |
+
|
223 |
+
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
|
224 |
+
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
|
225 |
+
|
226 |
+
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
|
227 |
+
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
|
228 |
+
|
229 |
+
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
230 |
+
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
231 |
+
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
232 |
+
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
233 |
+
return img1, img2, flow, valid
|
234 |
+
|
235 |
+
|
236 |
+
def __call__(self, img1, img2, flow, valid):
|
237 |
+
img1, img2 = self.color_transform(img1, img2)
|
238 |
+
img1, img2 = self.eraser_transform(img1, img2)
|
239 |
+
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
|
240 |
+
|
241 |
+
img1 = np.ascontiguousarray(img1)
|
242 |
+
img2 = np.ascontiguousarray(img2)
|
243 |
+
flow = np.ascontiguousarray(flow)
|
244 |
+
valid = np.ascontiguousarray(valid)
|
245 |
+
|
246 |
+
return img1, img2, flow, valid
|
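A quick, hypothetical smoke test for FlowAugmentor (not part of the repo): feed a random image pair and flow field and check that everything comes back cropped to crop_size. The import path assumes the repo root is on sys.path.

    import numpy as np
    from propainter.RAFT.utils.augmentor import FlowAugmentor

    aug = FlowAugmentor(crop_size=(368, 496))
    img1 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    img2 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    flow = np.random.randn(480, 640, 2).astype(np.float32)

    img1, img2, flow = aug(img1, img2, flow)
    print(img1.shape, img2.shape, flow.shape)   # (368, 496, 3) twice, then (368, 496, 2)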
propainter/RAFT/utils/flow_viz.py
ADDED
@@ -0,0 +1,132 @@
1 |
+
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
2 |
+
|
3 |
+
|
4 |
+
# MIT License
|
5 |
+
#
|
6 |
+
# Copyright (c) 2018 Tom Runia
|
7 |
+
#
|
8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
10 |
+
# in the Software without restriction, including without limitation the rights
|
11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
13 |
+
# furnished to do so, subject to conditions.
|
14 |
+
#
|
15 |
+
# Author: Tom Runia
|
16 |
+
# Date Created: 2018-08-03
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
def make_colorwheel():
|
21 |
+
"""
|
22 |
+
Generates a color wheel for optical flow visualization as presented in:
|
23 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
24 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
25 |
+
|
26 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
27 |
+
Code follows the Matlab source code of Deqing Sun.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
np.ndarray: Color wheel
|
31 |
+
"""
|
32 |
+
|
33 |
+
RY = 15
|
34 |
+
YG = 6
|
35 |
+
GC = 4
|
36 |
+
CB = 11
|
37 |
+
BM = 13
|
38 |
+
MR = 6
|
39 |
+
|
40 |
+
ncols = RY + YG + GC + CB + BM + MR
|
41 |
+
colorwheel = np.zeros((ncols, 3))
|
42 |
+
col = 0
|
43 |
+
|
44 |
+
# RY
|
45 |
+
colorwheel[0:RY, 0] = 255
|
46 |
+
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
|
47 |
+
col = col+RY
|
48 |
+
# YG
|
49 |
+
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
|
50 |
+
colorwheel[col:col+YG, 1] = 255
|
51 |
+
col = col+YG
|
52 |
+
# GC
|
53 |
+
colorwheel[col:col+GC, 1] = 255
|
54 |
+
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
|
55 |
+
col = col+GC
|
56 |
+
# CB
|
57 |
+
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
|
58 |
+
colorwheel[col:col+CB, 2] = 255
|
59 |
+
col = col+CB
|
60 |
+
# BM
|
61 |
+
colorwheel[col:col+BM, 2] = 255
|
62 |
+
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
|
63 |
+
col = col+BM
|
64 |
+
# MR
|
65 |
+
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
|
66 |
+
colorwheel[col:col+MR, 0] = 255
|
67 |
+
return colorwheel
|
68 |
+
|
69 |
+
|
70 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
71 |
+
"""
|
72 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
73 |
+
|
74 |
+
According to the C++ source code of Daniel Scharstein
|
75 |
+
According to the Matlab source code of Deqing Sun
|
76 |
+
|
77 |
+
Args:
|
78 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
79 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
80 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
84 |
+
"""
|
85 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
86 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
87 |
+
ncols = colorwheel.shape[0]
|
88 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
89 |
+
a = np.arctan2(-v, -u)/np.pi
|
90 |
+
fk = (a+1) / 2*(ncols-1)
|
91 |
+
k0 = np.floor(fk).astype(np.int32)
|
92 |
+
k1 = k0 + 1
|
93 |
+
k1[k1 == ncols] = 0
|
94 |
+
f = fk - k0
|
95 |
+
for i in range(colorwheel.shape[1]):
|
96 |
+
tmp = colorwheel[:,i]
|
97 |
+
col0 = tmp[k0] / 255.0
|
98 |
+
col1 = tmp[k1] / 255.0
|
99 |
+
col = (1-f)*col0 + f*col1
|
100 |
+
idx = (rad <= 1)
|
101 |
+
col[idx] = 1 - rad[idx] * (1-col[idx])
|
102 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
103 |
+
# Note the 2-i => BGR instead of RGB
|
104 |
+
ch_idx = 2-i if convert_to_bgr else i
|
105 |
+
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
106 |
+
return flow_image
|
107 |
+
|
108 |
+
|
109 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
110 |
+
"""
|
111 |
+
Expects a two dimensional flow image of shape [H,W,2].
|
112 |
+
|
113 |
+
Args:
|
114 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
115 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
116 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
120 |
+
"""
|
121 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
122 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
123 |
+
if clip_flow is not None:
|
124 |
+
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
125 |
+
u = flow_uv[:,:,0]
|
126 |
+
v = flow_uv[:,:,1]
|
127 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
128 |
+
rad_max = np.max(rad)
|
129 |
+
epsilon = 1e-5
|
130 |
+
u = u / (rad_max + epsilon)
|
131 |
+
v = v / (rad_max + epsilon)
|
132 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
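A short sketch of the expected call pattern for flow_to_image above: a synthetic radial flow field of shape [H, W, 2] maps to an [H, W, 3] uint8 color image (import path assumed, as before):

    import numpy as np
    from propainter.RAFT.utils.flow_viz import flow_to_image

    H, W = 120, 160
    y, x = np.mgrid[0:H, 0:W].astype(np.float32)
    flow = np.stack([x - W / 2, y - H / 2], axis=-1)   # flow vectors pointing outward

    rgb = flow_to_image(flow)
    print(rgb.shape, rgb.dtype)                        # (120, 160, 3) uint8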
propainter/RAFT/utils/flow_viz_pt.py
ADDED
@@ -0,0 +1,118 @@
1 |
+
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
|
2 |
+
import torch
|
3 |
+
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
4 |
+
|
5 |
+
@torch.no_grad()
|
6 |
+
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
|
7 |
+
|
8 |
+
"""
|
9 |
+
Converts a flow to an RGB image.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
|
16 |
+
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
|
17 |
+
"""
|
18 |
+
|
19 |
+
if flow.dtype != torch.float:
|
20 |
+
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
|
21 |
+
|
22 |
+
orig_shape = flow.shape
|
23 |
+
if flow.ndim == 3:
|
24 |
+
flow = flow[None] # Add batch dim
|
25 |
+
|
26 |
+
if flow.ndim != 4 or flow.shape[1] != 2:
|
27 |
+
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
|
28 |
+
|
29 |
+
max_norm = torch.sum(flow**2, dim=1).sqrt().max()
|
30 |
+
epsilon = torch.finfo((flow).dtype).eps
|
31 |
+
normalized_flow = flow / (max_norm + epsilon)
|
32 |
+
img = _normalized_flow_to_image(normalized_flow)
|
33 |
+
|
34 |
+
if len(orig_shape) == 3:
|
35 |
+
img = img[0] # Remove batch dim
|
36 |
+
return img
|
37 |
+
|
38 |
+
@torch.no_grad()
|
39 |
+
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
|
40 |
+
|
41 |
+
"""
|
42 |
+
Converts a batch of normalized flow to an RGB image.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
|
46 |
+
Returns:
|
47 |
+
img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
|
48 |
+
"""
|
49 |
+
|
50 |
+
N, _, H, W = normalized_flow.shape
|
51 |
+
device = normalized_flow.device
|
52 |
+
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
|
53 |
+
colorwheel = _make_colorwheel().to(device) # shape [55x3]
|
54 |
+
num_cols = colorwheel.shape[0]
|
55 |
+
norm = torch.sum(normalized_flow**2, dim=1).sqrt()
|
56 |
+
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
|
57 |
+
fk = (a + 1) / 2 * (num_cols - 1)
|
58 |
+
k0 = torch.floor(fk).to(torch.long)
|
59 |
+
k1 = k0 + 1
|
60 |
+
k1[k1 == num_cols] = 0
|
61 |
+
f = fk - k0
|
62 |
+
|
63 |
+
for c in range(colorwheel.shape[1]):
|
64 |
+
tmp = colorwheel[:, c]
|
65 |
+
col0 = tmp[k0] / 255.0
|
66 |
+
col1 = tmp[k1] / 255.0
|
67 |
+
col = (1 - f) * col0 + f * col1
|
68 |
+
col = 1 - norm * (1 - col)
|
69 |
+
flow_image[:, c, :, :] = torch.floor(255. * col)
|
70 |
+
return flow_image
|
71 |
+
|
72 |
+
|
73 |
+
@torch.no_grad()
|
74 |
+
def _make_colorwheel() -> torch.Tensor:
|
75 |
+
"""
|
76 |
+
Generates a color wheel for optical flow visualization as presented in:
|
77 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
78 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
colorwheel (Tensor[55, 3]): Colorwheel Tensor.
|
82 |
+
"""
|
83 |
+
|
84 |
+
RY = 15
|
85 |
+
YG = 6
|
86 |
+
GC = 4
|
87 |
+
CB = 11
|
88 |
+
BM = 13
|
89 |
+
MR = 6
|
90 |
+
|
91 |
+
ncols = RY + YG + GC + CB + BM + MR
|
92 |
+
colorwheel = torch.zeros((ncols, 3))
|
93 |
+
col = 0
|
94 |
+
|
95 |
+
# RY
|
96 |
+
colorwheel[0:RY, 0] = 255
|
97 |
+
colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY)
|
98 |
+
col = col + RY
|
99 |
+
# YG
|
100 |
+
colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG)
|
101 |
+
colorwheel[col : col + YG, 1] = 255
|
102 |
+
col = col + YG
|
103 |
+
# GC
|
104 |
+
colorwheel[col : col + GC, 1] = 255
|
105 |
+
colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC)
|
106 |
+
col = col + GC
|
107 |
+
# CB
|
108 |
+
colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB)
|
109 |
+
colorwheel[col : col + CB, 2] = 255
|
110 |
+
col = col + CB
|
111 |
+
# BM
|
112 |
+
colorwheel[col : col + BM, 2] = 255
|
113 |
+
colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM)
|
114 |
+
col = col + BM
|
115 |
+
# MR
|
116 |
+
colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR)
|
117 |
+
colorwheel[col : col + MR, 0] = 255
|
118 |
+
return colorwheel
|
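The PyTorch variant expects channel-first float tensors of shape (2, H, W) or (N, 2, H, W) rather than the numpy [H, W, 2] layout; a minimal sketch under the same import-path assumption:

    import torch
    from propainter.RAFT.utils.flow_viz_pt import flow_to_image

    flow = torch.randn(4, 2, 64, 64)        # a batch of 4 random flow fields
    img = flow_to_image(flow)
    print(img.shape, img.dtype)             # torch.Size([4, 3, 64, 64]) torch.uint8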
propainter/RAFT/utils/frame_utils.py
ADDED
@@ -0,0 +1,137 @@
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
from os.path import *
|
4 |
+
import re
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
cv2.setNumThreads(0)
|
8 |
+
cv2.ocl.setUseOpenCL(False)
|
9 |
+
|
10 |
+
TAG_CHAR = np.array([202021.25], np.float32)
|
11 |
+
|
12 |
+
def readFlow(fn):
|
13 |
+
""" Read .flo file in Middlebury format"""
|
14 |
+
# Code adapted from:
|
15 |
+
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
|
16 |
+
|
17 |
+
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
|
18 |
+
# print 'fn = %s'%(fn)
|
19 |
+
with open(fn, 'rb') as f:
|
20 |
+
magic = np.fromfile(f, np.float32, count=1)
|
21 |
+
if 202021.25 != magic:
|
22 |
+
print('Magic number incorrect. Invalid .flo file')
|
23 |
+
return None
|
24 |
+
else:
|
25 |
+
w = np.fromfile(f, np.int32, count=1)
|
26 |
+
h = np.fromfile(f, np.int32, count=1)
|
27 |
+
# print 'Reading %d x %d flo file\n' % (w, h)
|
28 |
+
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
|
29 |
+
# Reshape data into 3D array (columns, rows, bands)
|
30 |
+
# The reshape here is for visualization, the original code is (w,h,2)
|
31 |
+
return np.resize(data, (int(h), int(w), 2))
|
32 |
+
|
33 |
+
def readPFM(file):
|
34 |
+
file = open(file, 'rb')
|
35 |
+
|
36 |
+
color = None
|
37 |
+
width = None
|
38 |
+
height = None
|
39 |
+
scale = None
|
40 |
+
endian = None
|
41 |
+
|
42 |
+
header = file.readline().rstrip()
|
43 |
+
if header == b'PF':
|
44 |
+
color = True
|
45 |
+
elif header == b'Pf':
|
46 |
+
color = False
|
47 |
+
else:
|
48 |
+
raise Exception('Not a PFM file.')
|
49 |
+
|
50 |
+
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
|
51 |
+
if dim_match:
|
52 |
+
width, height = map(int, dim_match.groups())
|
53 |
+
else:
|
54 |
+
raise Exception('Malformed PFM header.')
|
55 |
+
|
56 |
+
scale = float(file.readline().rstrip())
|
57 |
+
if scale < 0: # little-endian
|
58 |
+
endian = '<'
|
59 |
+
scale = -scale
|
60 |
+
else:
|
61 |
+
endian = '>' # big-endian
|
62 |
+
|
63 |
+
data = np.fromfile(file, endian + 'f')
|
64 |
+
shape = (height, width, 3) if color else (height, width)
|
65 |
+
|
66 |
+
data = np.reshape(data, shape)
|
67 |
+
data = np.flipud(data)
|
68 |
+
return data
|
69 |
+
|
70 |
+
def writeFlow(filename,uv,v=None):
|
71 |
+
""" Write optical flow to file.
|
72 |
+
|
73 |
+
If v is None, uv is assumed to contain both u and v channels,
|
74 |
+
stacked in depth.
|
75 |
+
Original code by Deqing Sun, adapted from Daniel Scharstein.
|
76 |
+
"""
|
77 |
+
nBands = 2
|
78 |
+
|
79 |
+
if v is None:
|
80 |
+
assert(uv.ndim == 3)
|
81 |
+
assert(uv.shape[2] == 2)
|
82 |
+
u = uv[:,:,0]
|
83 |
+
v = uv[:,:,1]
|
84 |
+
else:
|
85 |
+
u = uv
|
86 |
+
|
87 |
+
assert(u.shape == v.shape)
|
88 |
+
height,width = u.shape
|
89 |
+
f = open(filename,'wb')
|
90 |
+
# write the header
|
91 |
+
f.write(TAG_CHAR)
|
92 |
+
np.array(width).astype(np.int32).tofile(f)
|
93 |
+
np.array(height).astype(np.int32).tofile(f)
|
94 |
+
# arrange into matrix form
|
95 |
+
tmp = np.zeros((height, width*nBands))
|
96 |
+
tmp[:,np.arange(width)*2] = u
|
97 |
+
tmp[:,np.arange(width)*2 + 1] = v
|
98 |
+
tmp.astype(np.float32).tofile(f)
|
99 |
+
f.close()
|
100 |
+
|
101 |
+
|
102 |
+
def readFlowKITTI(filename):
|
103 |
+
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
|
104 |
+
flow = flow[:,:,::-1].astype(np.float32)
|
105 |
+
flow, valid = flow[:, :, :2], flow[:, :, 2]
|
106 |
+
flow = (flow - 2**15) / 64.0
|
107 |
+
return flow, valid
|
108 |
+
|
109 |
+
def readDispKITTI(filename):
|
110 |
+
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
|
111 |
+
valid = disp > 0.0
|
112 |
+
flow = np.stack([-disp, np.zeros_like(disp)], -1)
|
113 |
+
return flow, valid
|
114 |
+
|
115 |
+
|
116 |
+
def writeFlowKITTI(filename, uv):
|
117 |
+
uv = 64.0 * uv + 2**15
|
118 |
+
valid = np.ones([uv.shape[0], uv.shape[1], 1])
|
119 |
+
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
|
120 |
+
cv2.imwrite(filename, uv[..., ::-1])
|
121 |
+
|
122 |
+
|
123 |
+
def read_gen(file_name, pil=False):
|
124 |
+
ext = splitext(file_name)[-1]
|
125 |
+
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
|
126 |
+
return Image.open(file_name)
|
127 |
+
elif ext == '.bin' or ext == '.raw':
|
128 |
+
return np.load(file_name)
|
129 |
+
elif ext == '.flo':
|
130 |
+
return readFlow(file_name).astype(np.float32)
|
131 |
+
elif ext == '.pfm':
|
132 |
+
flow = readPFM(file_name).astype(np.float32)
|
133 |
+
if len(flow.shape) == 2:
|
134 |
+
return flow
|
135 |
+
else:
|
136 |
+
return flow[:, :, :-1]
|
137 |
+
return []
|
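writeFlow and readFlow round-trip float32 flow fields through the Middlebury .flo format exactly; a small sketch (the output path is just a placeholder in a temp directory):

    import os
    import tempfile
    import numpy as np
    from propainter.RAFT.utils.frame_utils import writeFlow, readFlow

    flow = np.random.randn(32, 48, 2).astype(np.float32)
    path = os.path.join(tempfile.gettempdir(), 'example.flo')
    writeFlow(path, flow)
    restored = readFlow(path)
    print(np.allclose(flow, restored))      # True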
propainter/RAFT/utils/utils.py
ADDED
@@ -0,0 +1,82 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy import interpolate
|
5 |
+
|
6 |
+
|
7 |
+
class InputPadder:
|
8 |
+
""" Pads images such that dimensions are divisible by 8 """
|
9 |
+
def __init__(self, dims, mode='sintel'):
|
10 |
+
self.ht, self.wd = dims[-2:]
|
11 |
+
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
12 |
+
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
13 |
+
if mode == 'sintel':
|
14 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
15 |
+
else:
|
16 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
17 |
+
|
18 |
+
def pad(self, *inputs):
|
19 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
20 |
+
|
21 |
+
def unpad(self,x):
|
22 |
+
ht, wd = x.shape[-2:]
|
23 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
24 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
25 |
+
|
26 |
+
def forward_interpolate(flow):
|
27 |
+
flow = flow.detach().cpu().numpy()
|
28 |
+
dx, dy = flow[0], flow[1]
|
29 |
+
|
30 |
+
ht, wd = dx.shape
|
31 |
+
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
32 |
+
|
33 |
+
x1 = x0 + dx
|
34 |
+
y1 = y0 + dy
|
35 |
+
|
36 |
+
x1 = x1.reshape(-1)
|
37 |
+
y1 = y1.reshape(-1)
|
38 |
+
dx = dx.reshape(-1)
|
39 |
+
dy = dy.reshape(-1)
|
40 |
+
|
41 |
+
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
42 |
+
x1 = x1[valid]
|
43 |
+
y1 = y1[valid]
|
44 |
+
dx = dx[valid]
|
45 |
+
dy = dy[valid]
|
46 |
+
|
47 |
+
flow_x = interpolate.griddata(
|
48 |
+
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
49 |
+
|
50 |
+
flow_y = interpolate.griddata(
|
51 |
+
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
52 |
+
|
53 |
+
flow = np.stack([flow_x, flow_y], axis=0)
|
54 |
+
return torch.from_numpy(flow).float()
|
55 |
+
|
56 |
+
|
57 |
+
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
58 |
+
""" Wrapper for grid_sample, uses pixel coordinates """
|
59 |
+
H, W = img.shape[-2:]
|
60 |
+
xgrid, ygrid = coords.split([1,1], dim=-1)
|
61 |
+
xgrid = 2*xgrid/(W-1) - 1
|
62 |
+
ygrid = 2*ygrid/(H-1) - 1
|
63 |
+
|
64 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
65 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
66 |
+
|
67 |
+
if mask:
|
68 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
69 |
+
return img, mask.float()
|
70 |
+
|
71 |
+
return img
|
72 |
+
|
73 |
+
|
74 |
+
def coords_grid(batch, ht, wd):
|
75 |
+
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
76 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
77 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
78 |
+
|
79 |
+
|
80 |
+
def upflow8(flow, mode='bilinear'):
|
81 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
82 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
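InputPadder pads the last two dimensions up to multiples of 8 (as RAFT requires) and can undo the padding afterwards; a quick sketch:

    import torch
    from propainter.RAFT.utils.utils import InputPadder

    img = torch.randn(1, 3, 436, 1024)          # a Sintel-sized frame
    padder = InputPadder(img.shape)
    (img_pad,) = padder.pad(img)
    print(img_pad.shape)                        # torch.Size([1, 3, 440, 1024])
    print(padder.unpad(img_pad).shape)          # torch.Size([1, 3, 436, 1024])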
propainter/core/dataset.py
ADDED
@@ -0,0 +1,232 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
|
12 |
+
from utils.file_client import FileClient
|
13 |
+
from utils.img_util import imfrombytes
|
14 |
+
from utils.flow_util import resize_flow, flowread
|
15 |
+
from core.utils import (create_random_shape_with_random_motion, Stack,
|
16 |
+
ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip)
|
17 |
+
|
18 |
+
|
19 |
+
class TrainDataset(torch.utils.data.Dataset):
|
20 |
+
def __init__(self, args: dict):
|
21 |
+
self.args = args
|
22 |
+
self.video_root = args['video_root']
|
23 |
+
self.flow_root = args['flow_root']
|
24 |
+
self.num_local_frames = args['num_local_frames']
|
25 |
+
self.num_ref_frames = args['num_ref_frames']
|
26 |
+
self.size = self.w, self.h = (args['w'], args['h'])
|
27 |
+
|
28 |
+
self.load_flow = args['load_flow']
|
29 |
+
if self.load_flow:
|
30 |
+
assert os.path.exists(self.flow_root)
|
31 |
+
|
32 |
+
json_path = os.path.join('./datasets', args['name'], 'train.json')
|
33 |
+
|
34 |
+
with open(json_path, 'r') as f:
|
35 |
+
self.video_train_dict = json.load(f)
|
36 |
+
self.video_names = sorted(list(self.video_train_dict.keys()))
|
37 |
+
|
38 |
+
# self.video_names = sorted(os.listdir(self.video_root))
|
39 |
+
self.video_dict = {}
|
40 |
+
self.frame_dict = {}
|
41 |
+
|
42 |
+
for v in self.video_names:
|
43 |
+
frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
|
44 |
+
v_len = len(frame_list)
|
45 |
+
if v_len > self.num_local_frames + self.num_ref_frames:
|
46 |
+
self.video_dict[v] = v_len
|
47 |
+
self.frame_dict[v] = frame_list
|
48 |
+
|
49 |
+
|
50 |
+
self.video_names = list(self.video_dict.keys()) # update names
|
51 |
+
|
52 |
+
self._to_tensors = transforms.Compose([
|
53 |
+
Stack(),
|
54 |
+
ToTorchFormatTensor(),
|
55 |
+
])
|
56 |
+
self.file_client = FileClient('disk')
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.video_names)
|
60 |
+
|
61 |
+
def _sample_index(self, length, sample_length, num_ref_frame=3):
|
62 |
+
complete_idx_set = list(range(length))
|
63 |
+
pivot = random.randint(0, length - sample_length)
|
64 |
+
local_idx = complete_idx_set[pivot:pivot + sample_length]
|
65 |
+
remain_idx = list(set(complete_idx_set) - set(local_idx))
|
66 |
+
ref_index = sorted(random.sample(remain_idx, num_ref_frame))
|
67 |
+
|
68 |
+
return local_idx + ref_index
|
69 |
+
|
70 |
+
def __getitem__(self, index):
|
71 |
+
video_name = self.video_names[index]
|
72 |
+
# create masks
|
73 |
+
all_masks = create_random_shape_with_random_motion(
|
74 |
+
self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
|
75 |
+
|
76 |
+
# create sample index
|
77 |
+
selected_index = self._sample_index(self.video_dict[video_name],
|
78 |
+
self.num_local_frames,
|
79 |
+
self.num_ref_frames)
|
80 |
+
|
81 |
+
# read video frames
|
82 |
+
frames = []
|
83 |
+
masks = []
|
84 |
+
flows_f, flows_b = [], []
|
85 |
+
for idx in selected_index:
|
86 |
+
frame_list = self.frame_dict[video_name]
|
87 |
+
img_path = os.path.join(self.video_root, video_name, frame_list[idx])
|
88 |
+
img_bytes = self.file_client.get(img_path, 'img')
|
89 |
+
img = imfrombytes(img_bytes, float32=False)
|
90 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
91 |
+
img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
|
92 |
+
img = Image.fromarray(img)
|
93 |
+
|
94 |
+
frames.append(img)
|
95 |
+
masks.append(all_masks[idx])
|
96 |
+
|
97 |
+
if len(frames) <= self.num_local_frames-1 and self.load_flow:
|
98 |
+
current_n = frame_list[idx][:-4]
|
99 |
+
next_n = frame_list[idx+1][:-4]
|
100 |
+
flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
|
101 |
+
flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
|
102 |
+
flow_f = flowread(flow_f_path, quantize=False)
|
103 |
+
flow_b = flowread(flow_b_path, quantize=False)
|
104 |
+
flow_f = resize_flow(flow_f, self.h, self.w)
|
105 |
+
flow_b = resize_flow(flow_b, self.h, self.w)
|
106 |
+
flows_f.append(flow_f)
|
107 |
+
flows_b.append(flow_b)
|
108 |
+
|
109 |
+
if len(frames) == self.num_local_frames: # random reverse
|
110 |
+
if random.random() < 0.5:
|
111 |
+
frames.reverse()
|
112 |
+
masks.reverse()
|
113 |
+
if self.load_flow:
|
114 |
+
flows_f.reverse()
|
115 |
+
flows_b.reverse()
|
116 |
+
flows_ = flows_f
|
117 |
+
flows_f = flows_b
|
118 |
+
flows_b = flows_
|
119 |
+
|
120 |
+
if self.load_flow:
|
121 |
+
frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b)
|
122 |
+
else:
|
123 |
+
frames = GroupRandomHorizontalFlip()(frames)
|
124 |
+
|
125 |
+
# normalize, to tensors
|
126 |
+
frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
|
127 |
+
mask_tensors = self._to_tensors(masks)
|
128 |
+
if self.load_flow:
|
129 |
+
flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1
|
130 |
+
flows_b = np.stack(flows_b, axis=-1)
|
131 |
+
flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
|
132 |
+
flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
|
133 |
+
|
134 |
+
# img [-1,1] mask [0,1]
|
135 |
+
if self.load_flow:
|
136 |
+
return frame_tensors, mask_tensors, flows_f, flows_b, video_name
|
137 |
+
else:
|
138 |
+
return frame_tensors, mask_tensors, 'None', 'None', video_name
|
139 |
+
|
140 |
+
|
141 |
+
class TestDataset(torch.utils.data.Dataset):
|
142 |
+
def __init__(self, args):
|
143 |
+
self.args = args
|
144 |
+
self.size = self.w, self.h = args['size']
|
145 |
+
|
146 |
+
self.video_root = args['video_root']
|
147 |
+
self.mask_root = args['mask_root']
|
148 |
+
self.flow_root = args['flow_root']
|
149 |
+
|
150 |
+
self.load_flow = args['load_flow']
|
151 |
+
if self.load_flow:
|
152 |
+
assert os.path.exists(self.flow_root)
|
153 |
+
self.video_names = sorted(os.listdir(self.mask_root))
|
154 |
+
|
155 |
+
self.video_dict = {}
|
156 |
+
self.frame_dict = {}
|
157 |
+
|
158 |
+
for v in self.video_names:
|
159 |
+
frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
|
160 |
+
v_len = len(frame_list)
|
161 |
+
self.video_dict[v] = v_len
|
162 |
+
self.frame_dict[v] = frame_list
|
163 |
+
|
164 |
+
self._to_tensors = transforms.Compose([
|
165 |
+
Stack(),
|
166 |
+
ToTorchFormatTensor(),
|
167 |
+
])
|
168 |
+
self.file_client = FileClient('disk')
|
169 |
+
|
170 |
+
def __len__(self):
|
171 |
+
return len(self.video_names)
|
172 |
+
|
173 |
+
def __getitem__(self, index):
|
174 |
+
video_name = self.video_names[index]
|
175 |
+
selected_index = list(range(self.video_dict[video_name]))
|
176 |
+
|
177 |
+
# read video frames
|
178 |
+
frames = []
|
179 |
+
masks = []
|
180 |
+
flows_f, flows_b = [], []
|
181 |
+
for idx in selected_index:
|
182 |
+
frame_list = self.frame_dict[video_name]
|
183 |
+
frame_path = os.path.join(self.video_root, video_name, frame_list[idx])
|
184 |
+
|
185 |
+
img_bytes = self.file_client.get(frame_path, 'input')
|
186 |
+
img = imfrombytes(img_bytes, float32=False)
|
187 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
188 |
+
img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
|
189 |
+
img = Image.fromarray(img)
|
190 |
+
|
191 |
+
frames.append(img)
|
192 |
+
|
193 |
+
mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png')
|
194 |
+
mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L')
|
195 |
+
|
196 |
+
# origin: 0 indicates missing. now: 1 indicates missing
|
197 |
+
mask = np.asarray(mask)
|
198 |
+
m = np.array(mask > 0).astype(np.uint8)
|
199 |
+
|
200 |
+
m = cv2.dilate(m,
|
201 |
+
cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
|
202 |
+
iterations=4)
|
203 |
+
mask = Image.fromarray(m * 255)
|
204 |
+
masks.append(mask)
|
205 |
+
|
206 |
+
if len(frames) <= len(selected_index)-1 and self.load_flow:
|
207 |
+
current_n = frame_list[idx][:-4]
|
208 |
+
next_n = frame_list[idx+1][:-4]
|
209 |
+
flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
|
210 |
+
flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
|
211 |
+
flow_f = flowread(flow_f_path, quantize=False)
|
212 |
+
flow_b = flowread(flow_b_path, quantize=False)
|
213 |
+
flow_f = resize_flow(flow_f, self.h, self.w)
|
214 |
+
flow_b = resize_flow(flow_b, self.h, self.w)
|
215 |
+
flows_f.append(flow_f)
|
216 |
+
flows_b.append(flow_b)
|
217 |
+
|
218 |
+
# normalize, to tensors
|
219 |
+
frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
|
220 |
+
frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
|
221 |
+
mask_tensors = self._to_tensors(masks)
|
222 |
+
|
223 |
+
if self.load_flow:
|
224 |
+
flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1
|
225 |
+
flows_b = np.stack(flows_b, axis=-1)
|
226 |
+
flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
|
227 |
+
flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
|
228 |
+
|
229 |
+
if self.load_flow:
|
230 |
+
return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL
|
231 |
+
else:
|
232 |
+
return frame_tensors, mask_tensors, 'None', 'None', video_name
|
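TrainDataset pulls all of its settings from one plain dict; the keys it reads are gathered below. The values are placeholders only and must point at a real frame/flow layout plus a ./datasets/<name>/train.json index before the dataset can actually be constructed.

    train_args = {
        'name': 'my_dataset',                     # expects ./datasets/my_dataset/train.json
        'video_root': './datasets/my_dataset/JPEGImages',
        'flow_root': './datasets/my_dataset/flows',
        'load_flow': False,                       # True only if precomputed .flo files exist
        'num_local_frames': 10,
        'num_ref_frames': 1,
        'w': 432,
        'h': 240,
    }
    # dataset = TrainDataset(train_args)          # requires the files above on disk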
propainter/core/dist.py
ADDED
@@ -0,0 +1,47 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def get_world_size():
|
6 |
+
"""Find OMPI world size without calling mpi functions
|
7 |
+
:rtype: int
|
8 |
+
"""
|
9 |
+
if os.environ.get('PMI_SIZE') is not None:
|
10 |
+
return int(os.environ.get('PMI_SIZE') or 1)
|
11 |
+
elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
|
12 |
+
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
|
13 |
+
else:
|
14 |
+
return torch.cuda.device_count()
|
15 |
+
|
16 |
+
|
17 |
+
def get_global_rank():
|
18 |
+
"""Find OMPI world rank without calling mpi functions
|
19 |
+
:rtype: int
|
20 |
+
"""
|
21 |
+
if os.environ.get('PMI_RANK') is not None:
|
22 |
+
return int(os.environ.get('PMI_RANK') or 0)
|
23 |
+
elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
|
24 |
+
return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
|
25 |
+
else:
|
26 |
+
return 0
|
27 |
+
|
28 |
+
|
29 |
+
def get_local_rank():
|
30 |
+
"""Find OMPI local rank without calling mpi functions
|
31 |
+
:rtype: int
|
32 |
+
"""
|
33 |
+
if os.environ.get('MPI_LOCALRANKID') is not None:
|
34 |
+
return int(os.environ.get('MPI_LOCALRANKID') or 0)
|
35 |
+
elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
|
36 |
+
return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
|
37 |
+
else:
|
38 |
+
return 0
|
39 |
+
|
40 |
+
|
41 |
+
def get_master_ip():
|
42 |
+
if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
|
43 |
+
return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
|
44 |
+
elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
|
45 |
+
return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
|
46 |
+
else:
|
47 |
+
return "127.0.0.1"
|
propainter/core/loss.py
ADDED
@@ -0,0 +1,180 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import lpips
|
4 |
+
from model.vgg_arch import VGGFeatureExtractor
|
5 |
+
|
6 |
+
class PerceptualLoss(nn.Module):
|
7 |
+
"""Perceptual loss with commonly used style loss.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
layer_weights (dict): The weight for each layer of vgg feature.
|
11 |
+
Here is an example: {'conv5_4': 1.}, which means the conv5_4
|
12 |
+
feature layer (before relu5_4) will be extracted with weight
|
13 |
+
1.0 in calculating losses.
|
14 |
+
vgg_type (str): The type of vgg network used as feature extractor.
|
15 |
+
Default: 'vgg19'.
|
16 |
+
use_input_norm (bool): If True, normalize the input image in vgg.
|
17 |
+
Default: True.
|
18 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
19 |
+
Default: False.
|
20 |
+
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
|
21 |
+
loss will be calculated and the loss will be multiplied by the
|
22 |
+
weight. Default: 1.0.
|
23 |
+
style_weight (float): If `style_weight > 0`, the style loss will be
|
24 |
+
calculated and the loss will be multiplied by the weight.
|
25 |
+
Default: 0.
|
26 |
+
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self,
|
30 |
+
layer_weights,
|
31 |
+
vgg_type='vgg19',
|
32 |
+
use_input_norm=True,
|
33 |
+
range_norm=False,
|
34 |
+
perceptual_weight=1.0,
|
35 |
+
style_weight=0.,
|
36 |
+
criterion='l1'):
|
37 |
+
super(PerceptualLoss, self).__init__()
|
38 |
+
self.perceptual_weight = perceptual_weight
|
39 |
+
self.style_weight = style_weight
|
40 |
+
self.layer_weights = layer_weights
|
41 |
+
self.vgg = VGGFeatureExtractor(
|
42 |
+
layer_name_list=list(layer_weights.keys()),
|
43 |
+
vgg_type=vgg_type,
|
44 |
+
use_input_norm=use_input_norm,
|
45 |
+
range_norm=range_norm)
|
46 |
+
|
47 |
+
self.criterion_type = criterion
|
48 |
+
if self.criterion_type == 'l1':
|
49 |
+
self.criterion = torch.nn.L1Loss()
|
50 |
+
elif self.criterion_type == 'l2':
|
51 |
+
self.criterion = torch.nn.MSELoss()  # PyTorch has no nn.L2loss; MSELoss is the standard L2 criterion
|
52 |
+
elif self.criterion_type == 'mse':
|
53 |
+
self.criterion = torch.nn.MSELoss(reduction='mean')
|
54 |
+
elif self.criterion_type == 'fro':
|
55 |
+
self.criterion = None
|
56 |
+
else:
|
57 |
+
raise NotImplementedError(f'{criterion} criterion has not been supported.')
|
58 |
+
|
59 |
+
def forward(self, x, gt):
|
60 |
+
"""Forward function.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
64 |
+
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
Tensor: Forward results.
|
68 |
+
"""
|
69 |
+
# extract vgg features
|
70 |
+
x_features = self.vgg(x)
|
71 |
+
gt_features = self.vgg(gt.detach())
|
72 |
+
|
73 |
+
# calculate perceptual loss
|
74 |
+
if self.perceptual_weight > 0:
|
75 |
+
percep_loss = 0
|
76 |
+
for k in x_features.keys():
|
77 |
+
if self.criterion_type == 'fro':
|
78 |
+
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
|
79 |
+
else:
|
80 |
+
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
81 |
+
percep_loss *= self.perceptual_weight
|
82 |
+
else:
|
83 |
+
percep_loss = None
|
84 |
+
|
85 |
+
# calculate style loss
|
86 |
+
if self.style_weight > 0:
|
87 |
+
style_loss = 0
|
88 |
+
for k in x_features.keys():
|
89 |
+
if self.criterion_type == 'fro':
|
90 |
+
style_loss += torch.norm(
|
91 |
+
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
|
92 |
+
else:
|
93 |
+
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
|
94 |
+
gt_features[k])) * self.layer_weights[k]
|
95 |
+
style_loss *= self.style_weight
|
96 |
+
else:
|
97 |
+
style_loss = None
|
98 |
+
|
99 |
+
return percep_loss, style_loss
|
100 |
+
|
101 |
+
def _gram_mat(self, x):
|
102 |
+
"""Calculate Gram matrix.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
torch.Tensor: Gram matrix.
|
109 |
+
"""
|
110 |
+
n, c, h, w = x.size()
|
111 |
+
features = x.view(n, c, w * h)
|
112 |
+
features_t = features.transpose(1, 2)
|
113 |
+
gram = features.bmm(features_t) / (c * h * w)
|
114 |
+
return gram
|
115 |
+
|
116 |
+
class LPIPSLoss(nn.Module):
|
117 |
+
def __init__(self,
|
118 |
+
loss_weight=1.0,
|
119 |
+
use_input_norm=True,
|
120 |
+
range_norm=False,):
|
121 |
+
super(LPIPSLoss, self).__init__()
|
122 |
+
self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
|
123 |
+
self.loss_weight = loss_weight
|
124 |
+
self.use_input_norm = use_input_norm
|
125 |
+
self.range_norm = range_norm
|
126 |
+
|
127 |
+
if self.use_input_norm:
|
128 |
+
# the mean is for image with range [0, 1]
|
129 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
130 |
+
# the std is for image with range [0, 1]
|
131 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
132 |
+
|
133 |
+
def forward(self, pred, target):
|
134 |
+
if self.range_norm:
|
135 |
+
pred = (pred + 1) / 2
|
136 |
+
target = (target + 1) / 2
|
137 |
+
if self.use_input_norm:
|
138 |
+
pred = (pred - self.mean) / self.std
|
139 |
+
target = (target - self.mean) / self.std
|
140 |
+
lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
|
141 |
+
return self.loss_weight * lpips_loss.mean(), None
|
142 |
+
|
143 |
+
|
144 |
+
class AdversarialLoss(nn.Module):
|
145 |
+
r"""
|
146 |
+
Adversarial loss
|
147 |
+
https://arxiv.org/abs/1711.10337
|
148 |
+
"""
|
149 |
+
def __init__(self,
|
150 |
+
type='nsgan',
|
151 |
+
target_real_label=1.0,
|
152 |
+
target_fake_label=0.0):
|
153 |
+
r"""
|
154 |
+
type = nsgan | lsgan | hinge
|
155 |
+
"""
|
156 |
+
super(AdversarialLoss, self).__init__()
|
157 |
+
self.type = type
|
158 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
159 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
160 |
+
|
161 |
+
if type == 'nsgan':
|
162 |
+
self.criterion = nn.BCELoss()
|
163 |
+
elif type == 'lsgan':
|
164 |
+
self.criterion = nn.MSELoss()
|
165 |
+
elif type == 'hinge':
|
166 |
+
self.criterion = nn.ReLU()
|
167 |
+
|
168 |
+
def __call__(self, outputs, is_real, is_disc=None):
|
169 |
+
if self.type == 'hinge':
|
170 |
+
if is_disc:
|
171 |
+
if is_real:
|
172 |
+
outputs = -outputs
|
173 |
+
return self.criterion(1 + outputs).mean()
|
174 |
+
else:
|
175 |
+
return (-outputs).mean()
|
176 |
+
else:
|
177 |
+
labels = (self.real_label
|
178 |
+
if is_real else self.fake_label).expand_as(outputs)
|
179 |
+
loss = self.criterion(outputs, labels)
|
180 |
+
return loss
|
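AdversarialLoss wraps the usual GAN objectives (nsgan, lsgan, hinge). A minimal hinge sketch with random stand-in discriminator logits; note that loss.py itself imports lpips and model.vgg_arch, so this assumes the propainter directory is on sys.path and lpips is installed:

    import torch
    from core.loss import AdversarialLoss   # run from inside propainter/

    adv = AdversarialLoss(type='hinge')
    real_logits = torch.randn(4, 1, 30, 30)
    fake_logits = torch.randn(4, 1, 30, 30)

    # discriminator: push real logits up, fake logits down
    d_loss = adv(real_logits, is_real=True, is_disc=True) \
           + adv(fake_logits, is_real=False, is_disc=True)
    # generator: make fake logits look real
    g_loss = adv(fake_logits, is_real=True, is_disc=False)
    print(d_loss.item(), g_loss.item())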
propainter/core/lr_scheduler.py
ADDED
@@ -0,0 +1,112 @@
1 |
+
"""
|
2 |
+
LR scheduler from BasicSR https://github.com/xinntao/BasicSR
|
3 |
+
"""
|
4 |
+
import math
|
5 |
+
from collections import Counter
|
6 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
7 |
+
|
8 |
+
|
9 |
+
class MultiStepRestartLR(_LRScheduler):
|
10 |
+
""" MultiStep with restarts learning rate scheme.
|
11 |
+
Args:
|
12 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
13 |
+
milestones (list): Iterations that will decrease learning rate.
|
14 |
+
gamma (float): Decrease ratio. Default: 0.1.
|
15 |
+
restarts (list): Restart iterations. Default: [0].
|
16 |
+
restart_weights (list): Restart weights at each restart iteration.
|
17 |
+
Default: [1].
|
18 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
19 |
+
"""
|
20 |
+
def __init__(self,
|
21 |
+
optimizer,
|
22 |
+
milestones,
|
23 |
+
gamma=0.1,
|
24 |
+
restarts=(0, ),
|
25 |
+
restart_weights=(1, ),
|
26 |
+
last_epoch=-1):
|
27 |
+
self.milestones = Counter(milestones)
|
28 |
+
self.gamma = gamma
|
29 |
+
self.restarts = restarts
|
30 |
+
self.restart_weights = restart_weights
|
31 |
+
assert len(self.restarts) == len(
|
32 |
+
self.restart_weights), 'restarts and their weights do not match.'
|
33 |
+
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
|
34 |
+
|
35 |
+
def get_lr(self):
|
36 |
+
if self.last_epoch in self.restarts:
|
37 |
+
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
38 |
+
return [
|
39 |
+
group['initial_lr'] * weight
|
40 |
+
for group in self.optimizer.param_groups
|
41 |
+
]
|
42 |
+
if self.last_epoch not in self.milestones:
|
43 |
+
return [group['lr'] for group in self.optimizer.param_groups]
|
44 |
+
return [
|
45 |
+
group['lr'] * self.gamma**self.milestones[self.last_epoch]
|
46 |
+
for group in self.optimizer.param_groups
|
47 |
+
]
|
48 |
+
|
49 |
+
|
50 |
+
def get_position_from_periods(iteration, cumulative_period):
|
51 |
+
"""Get the position from a period list.
|
52 |
+
It will return the index of the right-closest number in the period list.
|
53 |
+
For example, the cumulative_period = [100, 200, 300, 400],
|
54 |
+
if iteration == 50, return 0;
|
55 |
+
if iteration == 210, return 2;
|
56 |
+
if iteration == 300, return 2.
|
57 |
+
Args:
|
58 |
+
iteration (int): Current iteration.
|
59 |
+
cumulative_period (list[int]): Cumulative period list.
|
60 |
+
Returns:
|
61 |
+
int: The position of the right-closest number in the period list.
|
62 |
+
"""
|
63 |
+
for i, period in enumerate(cumulative_period):
|
64 |
+
if iteration <= period:
|
65 |
+
return i
|
66 |
+
|
67 |
+
|
68 |
+
class CosineAnnealingRestartLR(_LRScheduler):
|
69 |
+
""" Cosine annealing with restarts learning rate scheme.
|
70 |
+
An example of config:
|
71 |
+
periods = [10, 10, 10, 10]
|
72 |
+
restart_weights = [1, 0.5, 0.5, 0.5]
|
73 |
+
eta_min=1e-7
|
74 |
+
It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th iterations, the
|
75 |
+
scheduler will restart with the weights in restart_weights.
|
76 |
+
Args:
|
77 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
78 |
+
periods (list): Period for each cosine annealing cycle.
|
79 |
+
restart_weights (list): Restart weights at each restart iteration.
|
80 |
+
Default: [1].
|
81 |
+
eta_min (float): The minimum lr. Default: 0.
|
82 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
83 |
+
"""
|
84 |
+
def __init__(self,
|
85 |
+
optimizer,
|
86 |
+
periods,
|
87 |
+
restart_weights=(1, ),
|
88 |
+
eta_min=1e-7,
|
89 |
+
last_epoch=-1):
|
90 |
+
self.periods = periods
|
91 |
+
self.restart_weights = restart_weights
|
92 |
+
self.eta_min = eta_min
|
93 |
+
assert (len(self.periods) == len(self.restart_weights)
|
94 |
+
), 'periods and restart_weights should have the same length.'
|
95 |
+
self.cumulative_period = [
|
96 |
+
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
|
97 |
+
]
|
98 |
+
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
|
99 |
+
|
100 |
+
def get_lr(self):
|
101 |
+
idx = get_position_from_periods(self.last_epoch,
|
102 |
+
self.cumulative_period)
|
103 |
+
current_weight = self.restart_weights[idx]
|
104 |
+
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
|
105 |
+
current_period = self.periods[idx]
|
106 |
+
|
107 |
+
return [
|
108 |
+
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
|
109 |
+
(1 + math.cos(math.pi * (
|
110 |
+
(self.last_epoch - nearest_restart) / current_period)))
|
111 |
+
for base_lr in self.base_lrs
|
112 |
+
]
|
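A short sketch of CosineAnnealingRestartLR driving a throwaway SGD optimizer through two 10-iteration cosine cycles, the second restarted at half weight (import path assumed as before):

    import torch
    from propainter.core.lr_scheduler import CosineAnnealingRestartLR

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.SGD(params, lr=1e-4)
    sched = CosineAnnealingRestartLR(opt, periods=[10, 10],
                                     restart_weights=[1, 0.5], eta_min=1e-7)

    for it in range(20):
        opt.step()
        sched.step()
        if it % 5 == 0:
            print(it, opt.param_groups[0]['lr'])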
propainter/core/metrics.py
ADDED
@@ -0,0 +1,571 @@
import numpy as np
# from skimage import measure
from skimage.metrics import structural_similarity as compare_ssim
from scipy import linalg

import torch
import torch.nn as nn
import torch.nn.functional as F

from propainter.core.utils import to_tensors


def calculate_epe(flow1, flow2):
    """Calculate End point errors."""

    epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt()
    epe = epe.view(-1)
    return epe.mean().item()


def calculate_psnr(img1, img2):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).
    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
    Returns:
        float: psnr result.
    """

    assert img1.shape == img2.shape, \
        (f'Image shapes are different: {img1.shape}, {img2.shape}.')

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))


def calc_psnr_and_ssim(img1, img2):
    """Calculate PSNR and SSIM for images.
        img1: ndarray, range [0, 255]
        img2: ndarray, range [0, 255]
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    psnr = calculate_psnr(img1, img2)
    ssim = compare_ssim(img1,
                        img2,
                        data_range=255,
                        multichannel=True,
                        win_size=65,
                        channel_axis=2)

    return psnr, ssim


###########################
# I3D models
###########################


def init_i3d_model(i3d_model_path):
    print(f"[Loading I3D model from {i3d_model_path} for FID score ..]")
    i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits')
    i3d_model.load_state_dict(torch.load(i3d_model_path))
    i3d_model.to(torch.device('cuda:0'))
    return i3d_model


def calculate_i3d_activations(video1, video2, i3d_model, device):
    """Calculate VFID metric.
    video1: list[PIL.Image]
    video2: list[PIL.Image]
    """
    video1 = to_tensors()(video1).unsqueeze(0).to(device)
    video2 = to_tensors()(video2).unsqueeze(0).to(device)
    video1_activations = get_i3d_activations(
        video1, i3d_model).cpu().numpy().flatten()
    video2_activations = get_i3d_activations(
        video2, i3d_model).cpu().numpy().flatten()

    return video1_activations, video2_activations


def calculate_vfid(real_activations, fake_activations):
    """
    Given two distributions of features, compute the FID score between them
    Params:
        real_activations: list[ndarray]
        fake_activations: list[ndarray]
    """
    m1 = np.mean(real_activations, axis=0)
    m2 = np.mean(fake_activations, axis=0)
    s1 = np.cov(real_activations, rowvar=False)
    s2 = np.cov(fake_activations, rowvar=False)
    return calculate_frechet_distance(m1, s1, m2, s2)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.
    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +  # NOQA
            np.trace(sigma2) - 2 * tr_covmean)


def get_i3d_activations(batched_video,
                        i3d_model,
                        target_endpoint='Logits',
                        flatten=True,
                        grad_enabled=False):
    """
    Get features from i3d model and flatten them to 1d feature,
    valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS
    VALID_ENDPOINTS = (
        'Conv3d_1a_7x7',
        'MaxPool3d_2a_3x3',
        'Conv3d_2b_1x1',
        'Conv3d_2c_3x3',
        'MaxPool3d_3a_3x3',
        'Mixed_3b',
        'Mixed_3c',
        'MaxPool3d_4a_3x3',
        'Mixed_4b',
        'Mixed_4c',
        'Mixed_4d',
        'Mixed_4e',
        'Mixed_4f',
        'MaxPool3d_5a_2x2',
        'Mixed_5b',
        'Mixed_5c',
        'Logits',
        'Predictions',
    )
    """
    with torch.set_grad_enabled(grad_enabled):
        feat = i3d_model.extract_features(batched_video.transpose(1, 2),
                                          target_endpoint)
    if flatten:
        feat = feat.view(feat.size(0), -1)

    return feat

193 |
+
# This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py
|
194 |
+
# I only fix flake8 errors and do some cleaning here
|
195 |
+
|
196 |
+
|
197 |
+
class MaxPool3dSamePadding(nn.MaxPool3d):
|
198 |
+
def compute_pad(self, dim, s):
|
199 |
+
if s % self.stride[dim] == 0:
|
200 |
+
return max(self.kernel_size[dim] - self.stride[dim], 0)
|
201 |
+
else:
|
202 |
+
return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
# compute 'same' padding
|
206 |
+
(batch, channel, t, h, w) = x.size()
|
207 |
+
pad_t = self.compute_pad(0, t)
|
208 |
+
pad_h = self.compute_pad(1, h)
|
209 |
+
pad_w = self.compute_pad(2, w)
|
210 |
+
|
211 |
+
pad_t_f = pad_t // 2
|
212 |
+
pad_t_b = pad_t - pad_t_f
|
213 |
+
pad_h_f = pad_h // 2
|
214 |
+
pad_h_b = pad_h - pad_h_f
|
215 |
+
pad_w_f = pad_w // 2
|
216 |
+
pad_w_b = pad_w - pad_w_f
|
217 |
+
|
218 |
+
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
219 |
+
x = F.pad(x, pad)
|
220 |
+
return super(MaxPool3dSamePadding, self).forward(x)
|
221 |
+
|
222 |
+
|
223 |
+
class Unit3D(nn.Module):
|
224 |
+
def __init__(self,
|
225 |
+
in_channels,
|
226 |
+
output_channels,
|
227 |
+
kernel_shape=(1, 1, 1),
|
228 |
+
stride=(1, 1, 1),
|
229 |
+
padding=0,
|
230 |
+
activation_fn=F.relu,
|
231 |
+
use_batch_norm=True,
|
232 |
+
use_bias=False,
|
233 |
+
name='unit_3d'):
|
234 |
+
"""Initializes Unit3D module."""
|
235 |
+
super(Unit3D, self).__init__()
|
236 |
+
|
237 |
+
self._output_channels = output_channels
|
238 |
+
self._kernel_shape = kernel_shape
|
239 |
+
self._stride = stride
|
240 |
+
self._use_batch_norm = use_batch_norm
|
241 |
+
self._activation_fn = activation_fn
|
242 |
+
self._use_bias = use_bias
|
243 |
+
self.name = name
|
244 |
+
self.padding = padding
|
245 |
+
|
246 |
+
self.conv3d = nn.Conv3d(
|
247 |
+
in_channels=in_channels,
|
248 |
+
out_channels=self._output_channels,
|
249 |
+
kernel_size=self._kernel_shape,
|
250 |
+
stride=self._stride,
|
251 |
+
padding=0, # we always want padding to be 0 here. We will
|
252 |
+
# dynamically pad based on input size in forward function
|
253 |
+
bias=self._use_bias)
|
254 |
+
|
255 |
+
if self._use_batch_norm:
|
256 |
+
self.bn = nn.BatchNorm3d(self._output_channels,
|
257 |
+
eps=0.001,
|
258 |
+
momentum=0.01)
|
259 |
+
|
260 |
+
def compute_pad(self, dim, s):
|
261 |
+
if s % self._stride[dim] == 0:
|
262 |
+
return max(self._kernel_shape[dim] - self._stride[dim], 0)
|
263 |
+
else:
|
264 |
+
return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
|
265 |
+
|
266 |
+
def forward(self, x):
|
267 |
+
# compute 'same' padding
|
268 |
+
(batch, channel, t, h, w) = x.size()
|
269 |
+
pad_t = self.compute_pad(0, t)
|
270 |
+
pad_h = self.compute_pad(1, h)
|
271 |
+
pad_w = self.compute_pad(2, w)
|
272 |
+
|
273 |
+
pad_t_f = pad_t // 2
|
274 |
+
pad_t_b = pad_t - pad_t_f
|
275 |
+
pad_h_f = pad_h // 2
|
276 |
+
pad_h_b = pad_h - pad_h_f
|
277 |
+
pad_w_f = pad_w // 2
|
278 |
+
pad_w_b = pad_w - pad_w_f
|
279 |
+
|
280 |
+
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
281 |
+
x = F.pad(x, pad)
|
282 |
+
|
283 |
+
x = self.conv3d(x)
|
284 |
+
if self._use_batch_norm:
|
285 |
+
x = self.bn(x)
|
286 |
+
if self._activation_fn is not None:
|
287 |
+
x = self._activation_fn(x)
|
288 |
+
return x
|
289 |
+
|
290 |
+
|
291 |
+
class InceptionModule(nn.Module):
|
292 |
+
def __init__(self, in_channels, out_channels, name):
|
293 |
+
super(InceptionModule, self).__init__()
|
294 |
+
|
295 |
+
self.b0 = Unit3D(in_channels=in_channels,
|
296 |
+
output_channels=out_channels[0],
|
297 |
+
kernel_shape=[1, 1, 1],
|
298 |
+
padding=0,
|
299 |
+
name=name + '/Branch_0/Conv3d_0a_1x1')
|
300 |
+
self.b1a = Unit3D(in_channels=in_channels,
|
301 |
+
output_channels=out_channels[1],
|
302 |
+
kernel_shape=[1, 1, 1],
|
303 |
+
padding=0,
|
304 |
+
name=name + '/Branch_1/Conv3d_0a_1x1')
|
305 |
+
self.b1b = Unit3D(in_channels=out_channels[1],
|
306 |
+
output_channels=out_channels[2],
|
307 |
+
kernel_shape=[3, 3, 3],
|
308 |
+
name=name + '/Branch_1/Conv3d_0b_3x3')
|
309 |
+
self.b2a = Unit3D(in_channels=in_channels,
|
310 |
+
output_channels=out_channels[3],
|
311 |
+
kernel_shape=[1, 1, 1],
|
312 |
+
padding=0,
|
313 |
+
name=name + '/Branch_2/Conv3d_0a_1x1')
|
314 |
+
self.b2b = Unit3D(in_channels=out_channels[3],
|
315 |
+
output_channels=out_channels[4],
|
316 |
+
kernel_shape=[3, 3, 3],
|
317 |
+
name=name + '/Branch_2/Conv3d_0b_3x3')
|
318 |
+
self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
|
319 |
+
stride=(1, 1, 1),
|
320 |
+
padding=0)
|
321 |
+
self.b3b = Unit3D(in_channels=in_channels,
|
322 |
+
output_channels=out_channels[5],
|
323 |
+
kernel_shape=[1, 1, 1],
|
324 |
+
padding=0,
|
325 |
+
name=name + '/Branch_3/Conv3d_0b_1x1')
|
326 |
+
self.name = name
|
327 |
+
|
328 |
+
def forward(self, x):
|
329 |
+
b0 = self.b0(x)
|
330 |
+
b1 = self.b1b(self.b1a(x))
|
331 |
+
b2 = self.b2b(self.b2a(x))
|
332 |
+
b3 = self.b3b(self.b3a(x))
|
333 |
+
return torch.cat([b0, b1, b2, b3], dim=1)
|
334 |
+
|
335 |
+
|
336 |
+
class InceptionI3d(nn.Module):
|
337 |
+
"""Inception-v1 I3D architecture.
|
338 |
+
The model is introduced in:
|
339 |
+
Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
|
340 |
+
Joao Carreira, Andrew Zisserman
|
341 |
+
https://arxiv.org/pdf/1705.07750v1.pdf.
|
342 |
+
See also the Inception architecture, introduced in:
|
343 |
+
Going deeper with convolutions
|
344 |
+
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
|
345 |
+
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
|
346 |
+
http://arxiv.org/pdf/1409.4842v1.pdf.
|
347 |
+
"""
|
348 |
+
|
349 |
+
# Endpoints of the model in order. During construction, all the endpoints up
|
350 |
+
# to a designated `final_endpoint` are returned in a dictionary as the
|
351 |
+
# second return value.
|
352 |
+
VALID_ENDPOINTS = (
|
353 |
+
'Conv3d_1a_7x7',
|
354 |
+
'MaxPool3d_2a_3x3',
|
355 |
+
'Conv3d_2b_1x1',
|
356 |
+
'Conv3d_2c_3x3',
|
357 |
+
'MaxPool3d_3a_3x3',
|
358 |
+
'Mixed_3b',
|
359 |
+
'Mixed_3c',
|
360 |
+
'MaxPool3d_4a_3x3',
|
361 |
+
'Mixed_4b',
|
362 |
+
'Mixed_4c',
|
363 |
+
'Mixed_4d',
|
364 |
+
'Mixed_4e',
|
365 |
+
'Mixed_4f',
|
366 |
+
'MaxPool3d_5a_2x2',
|
367 |
+
'Mixed_5b',
|
368 |
+
'Mixed_5c',
|
369 |
+
'Logits',
|
370 |
+
'Predictions',
|
371 |
+
)
|
372 |
+
|
373 |
+
def __init__(self,
|
374 |
+
num_classes=400,
|
375 |
+
spatial_squeeze=True,
|
376 |
+
final_endpoint='Logits',
|
377 |
+
name='inception_i3d',
|
378 |
+
in_channels=3,
|
379 |
+
dropout_keep_prob=0.5):
|
380 |
+
"""Initializes I3D model instance.
|
381 |
+
Args:
|
382 |
+
num_classes: The number of outputs in the logit layer (default 400, which
|
383 |
+
matches the Kinetics dataset).
|
384 |
+
spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
|
385 |
+
before returning (default True).
|
386 |
+
final_endpoint: The model contains many possible endpoints.
|
387 |
+
`final_endpoint` specifies the last endpoint for the model to be built
|
388 |
+
up to. In addition to the output at `final_endpoint`, all the outputs
|
389 |
+
at endpoints up to `final_endpoint` will also be returned, in a
|
390 |
+
dictionary. `final_endpoint` must be one of
|
391 |
+
InceptionI3d.VALID_ENDPOINTS (default 'Logits').
|
392 |
+
name: A string (optional). The name of this module.
|
393 |
+
Raises:
|
394 |
+
ValueError: if `final_endpoint` is not recognized.
|
395 |
+
"""
|
396 |
+
|
397 |
+
if final_endpoint not in self.VALID_ENDPOINTS:
|
398 |
+
raise ValueError('Unknown final endpoint %s' % final_endpoint)
|
399 |
+
|
400 |
+
super(InceptionI3d, self).__init__()
|
401 |
+
self._num_classes = num_classes
|
402 |
+
self._spatial_squeeze = spatial_squeeze
|
403 |
+
self._final_endpoint = final_endpoint
|
404 |
+
self.logits = None
|
405 |
+
|
406 |
+
if self._final_endpoint not in self.VALID_ENDPOINTS:
|
407 |
+
raise ValueError('Unknown final endpoint %s' %
|
408 |
+
self._final_endpoint)
|
409 |
+
|
410 |
+
self.end_points = {}
|
411 |
+
end_point = 'Conv3d_1a_7x7'
|
412 |
+
self.end_points[end_point] = Unit3D(in_channels=in_channels,
|
413 |
+
output_channels=64,
|
414 |
+
kernel_shape=[7, 7, 7],
|
415 |
+
stride=(2, 2, 2),
|
416 |
+
padding=(3, 3, 3),
|
417 |
+
name=name + end_point)
|
418 |
+
if self._final_endpoint == end_point:
|
419 |
+
return
|
420 |
+
|
421 |
+
end_point = 'MaxPool3d_2a_3x3'
|
422 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
423 |
+
kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
|
424 |
+
if self._final_endpoint == end_point:
|
425 |
+
return
|
426 |
+
|
427 |
+
end_point = 'Conv3d_2b_1x1'
|
428 |
+
self.end_points[end_point] = Unit3D(in_channels=64,
|
429 |
+
output_channels=64,
|
430 |
+
kernel_shape=[1, 1, 1],
|
431 |
+
padding=0,
|
432 |
+
name=name + end_point)
|
433 |
+
if self._final_endpoint == end_point:
|
434 |
+
return
|
435 |
+
|
436 |
+
end_point = 'Conv3d_2c_3x3'
|
437 |
+
self.end_points[end_point] = Unit3D(in_channels=64,
|
438 |
+
output_channels=192,
|
439 |
+
kernel_shape=[3, 3, 3],
|
440 |
+
padding=1,
|
441 |
+
name=name + end_point)
|
442 |
+
if self._final_endpoint == end_point:
|
443 |
+
return
|
444 |
+
|
445 |
+
end_point = 'MaxPool3d_3a_3x3'
|
446 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
447 |
+
kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
|
448 |
+
if self._final_endpoint == end_point:
|
449 |
+
return
|
450 |
+
|
451 |
+
end_point = 'Mixed_3b'
|
452 |
+
self.end_points[end_point] = InceptionModule(192,
|
453 |
+
[64, 96, 128, 16, 32, 32],
|
454 |
+
name + end_point)
|
455 |
+
if self._final_endpoint == end_point:
|
456 |
+
return
|
457 |
+
|
458 |
+
end_point = 'Mixed_3c'
|
459 |
+
self.end_points[end_point] = InceptionModule(
|
460 |
+
256, [128, 128, 192, 32, 96, 64], name + end_point)
|
461 |
+
if self._final_endpoint == end_point:
|
462 |
+
return
|
463 |
+
|
464 |
+
end_point = 'MaxPool3d_4a_3x3'
|
465 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
466 |
+
kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0)
|
467 |
+
if self._final_endpoint == end_point:
|
468 |
+
return
|
469 |
+
|
470 |
+
end_point = 'Mixed_4b'
|
471 |
+
self.end_points[end_point] = InceptionModule(
|
472 |
+
128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
|
473 |
+
if self._final_endpoint == end_point:
|
474 |
+
return
|
475 |
+
|
476 |
+
end_point = 'Mixed_4c'
|
477 |
+
self.end_points[end_point] = InceptionModule(
|
478 |
+
192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
|
479 |
+
if self._final_endpoint == end_point:
|
480 |
+
return
|
481 |
+
|
482 |
+
end_point = 'Mixed_4d'
|
483 |
+
self.end_points[end_point] = InceptionModule(
|
484 |
+
160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
|
485 |
+
if self._final_endpoint == end_point:
|
486 |
+
return
|
487 |
+
|
488 |
+
end_point = 'Mixed_4e'
|
489 |
+
self.end_points[end_point] = InceptionModule(
|
490 |
+
128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
|
491 |
+
if self._final_endpoint == end_point:
|
492 |
+
return
|
493 |
+
|
494 |
+
end_point = 'Mixed_4f'
|
495 |
+
self.end_points[end_point] = InceptionModule(
|
496 |
+
112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
|
497 |
+
name + end_point)
|
498 |
+
if self._final_endpoint == end_point:
|
499 |
+
return
|
500 |
+
|
501 |
+
end_point = 'MaxPool3d_5a_2x2'
|
502 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
503 |
+
kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0)
|
504 |
+
if self._final_endpoint == end_point:
|
505 |
+
return
|
506 |
+
|
507 |
+
end_point = 'Mixed_5b'
|
508 |
+
self.end_points[end_point] = InceptionModule(
|
509 |
+
256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
|
510 |
+
name + end_point)
|
511 |
+
if self._final_endpoint == end_point:
|
512 |
+
return
|
513 |
+
|
514 |
+
end_point = 'Mixed_5c'
|
515 |
+
self.end_points[end_point] = InceptionModule(
|
516 |
+
256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
|
517 |
+
name + end_point)
|
518 |
+
if self._final_endpoint == end_point:
|
519 |
+
return
|
520 |
+
|
521 |
+
end_point = 'Logits'
|
522 |
+
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
|
523 |
+
self.dropout = nn.Dropout(dropout_keep_prob)
|
524 |
+
self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
|
525 |
+
output_channels=self._num_classes,
|
526 |
+
kernel_shape=[1, 1, 1],
|
527 |
+
padding=0,
|
528 |
+
activation_fn=None,
|
529 |
+
use_batch_norm=False,
|
530 |
+
use_bias=True,
|
531 |
+
name='logits')
|
532 |
+
|
533 |
+
self.build()
|
534 |
+
|
535 |
+
def replace_logits(self, num_classes):
|
536 |
+
self._num_classes = num_classes
|
537 |
+
self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
|
538 |
+
output_channels=self._num_classes,
|
539 |
+
kernel_shape=[1, 1, 1],
|
540 |
+
padding=0,
|
541 |
+
activation_fn=None,
|
542 |
+
use_batch_norm=False,
|
543 |
+
use_bias=True,
|
544 |
+
name='logits')
|
545 |
+
|
546 |
+
def build(self):
|
547 |
+
for k in self.end_points.keys():
|
548 |
+
self.add_module(k, self.end_points[k])
|
549 |
+
|
550 |
+
def forward(self, x):
|
551 |
+
for end_point in self.VALID_ENDPOINTS:
|
552 |
+
if end_point in self.end_points:
|
553 |
+
x = self._modules[end_point](
|
554 |
+
x) # use _modules to work with dataparallel
|
555 |
+
|
556 |
+
x = self.logits(self.dropout(self.avg_pool(x)))
|
557 |
+
if self._spatial_squeeze:
|
558 |
+
logits = x.squeeze(3).squeeze(3)
|
559 |
+
# logits is batch X time X classes, which is what we want to work with
|
560 |
+
return logits
|
561 |
+
|
562 |
+
def extract_features(self, x, target_endpoint='Logits'):
|
563 |
+
for end_point in self.VALID_ENDPOINTS:
|
564 |
+
if end_point in self.end_points:
|
565 |
+
x = self._modules[end_point](x)
|
566 |
+
if end_point == target_endpoint:
|
567 |
+
break
|
568 |
+
if target_endpoint == 'Logits':
|
569 |
+
return x.mean(4).mean(3).mean(2)
|
570 |
+
else:
|
571 |
+
return x
|
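A small usage sketch for the metric helpers above (not part of the repository): it checks calculate_psnr on a noisy image pair and calculate_frechet_distance on two synthetic Gaussian feature sets. The import path is an assumption based on the package layout shown in this diff, and the I3D/VFID path is skipped because it needs the pretrained I3D weights.

import numpy as np
from propainter.core.metrics import calculate_psnr, calculate_frechet_distance

rng = np.random.default_rng(0)

# PSNR between an image and a slightly noisy copy, both in [0, 255].
img = rng.integers(0, 256, size=(64, 64, 3)).astype(np.float64)
noisy = np.clip(img + rng.normal(0, 5, img.shape), 0, 255)
print('PSNR:', calculate_psnr(img, noisy))

# FID between two feature sets, using the same mean/covariance recipe as calculate_vfid.
real = rng.normal(0.0, 1.0, size=(200, 16))
fake = rng.normal(0.1, 1.0, size=(200, 16))
m1, s1 = real.mean(axis=0), np.cov(real, rowvar=False)
m2, s2 = fake.mean(axis=0), np.cov(fake, rowvar=False)
print('FID:', calculate_frechet_distance(m1, s1, m2, s2))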
propainter/core/prefetch_dataloader.py
ADDED
@@ -0,0 +1,125 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
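As a usage sketch (not part of the repository, and the import path is an assumption based on this diff's layout): PrefetchDataLoader pushes batches into a background thread's queue, and CPUPrefetcher exposes the next()/reset() pattern that the trainers below rely on. The dataset here is a stand-in.

import torch
from torch.utils.data import TensorDataset
from propainter.core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher

dataset = TensorDataset(torch.arange(32, dtype=torch.float32).unsqueeze(1))
loader = PrefetchDataLoader(num_prefetch_queue=4,
                            dataset=dataset,
                            batch_size=8,
                            shuffle=True,
                            num_workers=0)
prefetcher = CPUPrefetcher(loader)

for epoch in range(2):
    prefetcher.reset()
    batch = prefetcher.next()
    while batch is not None:
        (x,) = batch            # TensorDataset yields a tuple/list of tensors
        _ = x.mean()            # a real training step would go here
        batch = prefetcher.next()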
propainter/core/trainer.py
ADDED
@@ -0,0 +1,509 @@
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import importlib
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
13 |
+
import torchvision
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
|
16 |
+
from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
|
17 |
+
from core.loss import AdversarialLoss, PerceptualLoss, LPIPSLoss
|
18 |
+
from core.dataset import TrainDataset
|
19 |
+
|
20 |
+
from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
|
21 |
+
from model.recurrent_flow_completion import RecurrentFlowCompleteNet
|
22 |
+
|
23 |
+
from RAFT.utils.flow_viz_pt import flow_to_image
|
24 |
+
|
25 |
+
|
26 |
+
class Trainer:
|
27 |
+
def __init__(self, config):
|
28 |
+
self.config = config
|
29 |
+
self.epoch = 0
|
30 |
+
self.iteration = 0
|
31 |
+
self.num_local_frames = config['train_data_loader']['num_local_frames']
|
32 |
+
self.num_ref_frames = config['train_data_loader']['num_ref_frames']
|
33 |
+
|
34 |
+
# setup data set and data loader
|
35 |
+
self.train_dataset = TrainDataset(config['train_data_loader'])
|
36 |
+
|
37 |
+
self.train_sampler = None
|
38 |
+
self.train_args = config['trainer']
|
39 |
+
if config['distributed']:
|
40 |
+
self.train_sampler = DistributedSampler(
|
41 |
+
self.train_dataset,
|
42 |
+
num_replicas=config['world_size'],
|
43 |
+
rank=config['global_rank'])
|
44 |
+
|
45 |
+
dataloader_args = dict(
|
46 |
+
dataset=self.train_dataset,
|
47 |
+
batch_size=self.train_args['batch_size'] // config['world_size'],
|
48 |
+
shuffle=(self.train_sampler is None),
|
49 |
+
num_workers=self.train_args['num_workers'],
|
50 |
+
sampler=self.train_sampler,
|
51 |
+
drop_last=True)
|
52 |
+
|
53 |
+
self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
|
54 |
+
self.prefetcher = CPUPrefetcher(self.train_loader)
|
55 |
+
|
56 |
+
# set loss functions
|
57 |
+
self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
|
58 |
+
self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
|
59 |
+
self.l1_loss = nn.L1Loss()
|
60 |
+
# self.perc_loss = PerceptualLoss(
|
61 |
+
# layer_weights={'conv3_4': 0.25, 'conv4_4': 0.25, 'conv5_4': 0.5},
|
62 |
+
# use_input_norm=True,
|
63 |
+
# range_norm=True,
|
64 |
+
# criterion='l1'
|
65 |
+
# ).to(self.config['device'])
|
66 |
+
|
67 |
+
if self.config['losses']['perceptual_weight'] > 0:
|
68 |
+
self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device'])
|
69 |
+
|
70 |
+
# self.flow_comp_loss = FlowCompletionLoss().to(self.config['device'])
|
71 |
+
# self.flow_comp_loss = FlowCompletionLoss(self.config['device'])
|
72 |
+
|
73 |
+
# set raft
|
74 |
+
self.fix_raft = RAFT_bi(device = self.config['device'])
|
75 |
+
self.fix_flow_complete = RecurrentFlowCompleteNet('weights/recurrent_flow_completion.pth')
|
76 |
+
for p in self.fix_flow_complete.parameters():
|
77 |
+
p.requires_grad = False
|
78 |
+
self.fix_flow_complete.to(self.config['device'])
|
79 |
+
self.fix_flow_complete.eval()
|
80 |
+
|
81 |
+
# self.flow_loss = FlowLoss()
|
82 |
+
|
83 |
+
# setup models including generator and discriminator
|
84 |
+
net = importlib.import_module('model.' + config['model']['net'])
|
85 |
+
self.netG = net.InpaintGenerator()
|
86 |
+
# print(self.netG)
|
87 |
+
self.netG = self.netG.to(self.config['device'])
|
88 |
+
if not self.config['model'].get('no_dis', False):
|
89 |
+
if self.config['model'].get('dis_2d', False):
|
90 |
+
self.netD = net.Discriminator_2D(
|
91 |
+
in_channels=3,
|
92 |
+
use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
|
93 |
+
else:
|
94 |
+
self.netD = net.Discriminator(
|
95 |
+
in_channels=3,
|
96 |
+
use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
|
97 |
+
self.netD = self.netD.to(self.config['device'])
|
98 |
+
|
99 |
+
self.interp_mode = self.config['model']['interp_mode']
|
100 |
+
# setup optimizers and schedulers
|
101 |
+
self.setup_optimizers()
|
102 |
+
self.setup_schedulers()
|
103 |
+
self.load()
|
104 |
+
|
105 |
+
if config['distributed']:
|
106 |
+
self.netG = DDP(self.netG,
|
107 |
+
device_ids=[self.config['local_rank']],
|
108 |
+
output_device=self.config['local_rank'],
|
109 |
+
broadcast_buffers=True,
|
110 |
+
find_unused_parameters=True)
|
111 |
+
if not self.config['model']['no_dis']:
|
112 |
+
self.netD = DDP(self.netD,
|
113 |
+
device_ids=[self.config['local_rank']],
|
114 |
+
output_device=self.config['local_rank'],
|
115 |
+
broadcast_buffers=True,
|
116 |
+
find_unused_parameters=False)
|
117 |
+
|
118 |
+
# set summary writer
|
119 |
+
self.dis_writer = None
|
120 |
+
self.gen_writer = None
|
121 |
+
self.summary = {}
|
122 |
+
if self.config['global_rank'] == 0 or (not config['distributed']):
|
123 |
+
if not self.config['model']['no_dis']:
|
124 |
+
self.dis_writer = SummaryWriter(
|
125 |
+
os.path.join(config['save_dir'], 'dis'))
|
126 |
+
self.gen_writer = SummaryWriter(
|
127 |
+
os.path.join(config['save_dir'], 'gen'))
|
128 |
+
|
129 |
+
def setup_optimizers(self):
|
130 |
+
"""Set up optimizers."""
|
131 |
+
backbone_params = []
|
132 |
+
for name, param in self.netG.named_parameters():
|
133 |
+
if param.requires_grad:
|
134 |
+
backbone_params.append(param)
|
135 |
+
else:
|
136 |
+
print(f'Params {name} will not be optimized.')
|
137 |
+
|
138 |
+
optim_params = [
|
139 |
+
{
|
140 |
+
'params': backbone_params,
|
141 |
+
'lr': self.config['trainer']['lr']
|
142 |
+
},
|
143 |
+
]
|
144 |
+
|
145 |
+
self.optimG = torch.optim.Adam(optim_params,
|
146 |
+
betas=(self.config['trainer']['beta1'],
|
147 |
+
self.config['trainer']['beta2']))
|
148 |
+
|
149 |
+
if not self.config['model']['no_dis']:
|
150 |
+
self.optimD = torch.optim.Adam(
|
151 |
+
self.netD.parameters(),
|
152 |
+
lr=self.config['trainer']['lr'],
|
153 |
+
betas=(self.config['trainer']['beta1'],
|
154 |
+
self.config['trainer']['beta2']))
|
155 |
+
|
156 |
+
def setup_schedulers(self):
|
157 |
+
"""Set up schedulers."""
|
158 |
+
scheduler_opt = self.config['trainer']['scheduler']
|
159 |
+
scheduler_type = scheduler_opt.pop('type')
|
160 |
+
|
161 |
+
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
|
162 |
+
self.scheG = MultiStepRestartLR(
|
163 |
+
self.optimG,
|
164 |
+
milestones=scheduler_opt['milestones'],
|
165 |
+
gamma=scheduler_opt['gamma'])
|
166 |
+
if not self.config['model']['no_dis']:
|
167 |
+
self.scheD = MultiStepRestartLR(
|
168 |
+
self.optimD,
|
169 |
+
milestones=scheduler_opt['milestones'],
|
170 |
+
gamma=scheduler_opt['gamma'])
|
171 |
+
elif scheduler_type == 'CosineAnnealingRestartLR':
|
172 |
+
self.scheG = CosineAnnealingRestartLR(
|
173 |
+
self.optimG,
|
174 |
+
periods=scheduler_opt['periods'],
|
175 |
+
restart_weights=scheduler_opt['restart_weights'],
|
176 |
+
eta_min=scheduler_opt['eta_min'])
|
177 |
+
if not self.config['model']['no_dis']:
|
178 |
+
self.scheD = CosineAnnealingRestartLR(
|
179 |
+
self.optimD,
|
180 |
+
periods=scheduler_opt['periods'],
|
181 |
+
restart_weights=scheduler_opt['restart_weights'],
|
182 |
+
eta_min=scheduler_opt['eta_min'])
|
183 |
+
else:
|
184 |
+
raise NotImplementedError(
|
185 |
+
f'Scheduler {scheduler_type} is not implemented yet.')
|
186 |
+
|
187 |
+
def update_learning_rate(self):
|
188 |
+
"""Update learning rate."""
|
189 |
+
self.scheG.step()
|
190 |
+
if not self.config['model']['no_dis']:
|
191 |
+
self.scheD.step()
|
192 |
+
|
193 |
+
def get_lr(self):
|
194 |
+
"""Get current learning rate."""
|
195 |
+
return self.optimG.param_groups[0]['lr']
|
196 |
+
|
197 |
+
def add_summary(self, writer, name, val):
|
198 |
+
"""Add tensorboard summary."""
|
199 |
+
if name not in self.summary:
|
200 |
+
self.summary[name] = 0
|
201 |
+
self.summary[name] += val
|
202 |
+
n = self.train_args['log_freq']
|
203 |
+
if writer is not None and self.iteration % n == 0:
|
204 |
+
writer.add_scalar(name, self.summary[name] / n, self.iteration)
|
205 |
+
self.summary[name] = 0
|
206 |
+
|
207 |
+
def load(self):
|
208 |
+
"""Load netG (and netD)."""
|
209 |
+
# get the latest checkpoint
|
210 |
+
model_path = self.config['save_dir']
|
211 |
+
# TODO: add resume name
|
212 |
+
if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
|
213 |
+
latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
|
214 |
+
'r').read().splitlines()[-1]
|
215 |
+
else:
|
216 |
+
ckpts = [
|
217 |
+
os.path.basename(i).split('.pth')[0]
|
218 |
+
for i in glob.glob(os.path.join(model_path, '*.pth'))
|
219 |
+
]
|
220 |
+
ckpts.sort()
|
221 |
+
latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None
|
222 |
+
|
223 |
+
if latest_epoch is not None:
|
224 |
+
gen_path = os.path.join(model_path,
|
225 |
+
f'gen_{int(latest_epoch):06d}.pth')
|
226 |
+
dis_path = os.path.join(model_path,
|
227 |
+
f'dis_{int(latest_epoch):06d}.pth')
|
228 |
+
opt_path = os.path.join(model_path,
|
229 |
+
f'opt_{int(latest_epoch):06d}.pth')
|
230 |
+
|
231 |
+
if self.config['global_rank'] == 0:
|
232 |
+
print(f'Loading model from {gen_path}...')
|
233 |
+
dataG = torch.load(gen_path, map_location=self.config['device'])
|
234 |
+
self.netG.load_state_dict(dataG)
|
235 |
+
if not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
236 |
+
dataD = torch.load(dis_path, map_location=self.config['device'])
|
237 |
+
self.netD.load_state_dict(dataD)
|
238 |
+
|
239 |
+
data_opt = torch.load(opt_path, map_location=self.config['device'])
|
240 |
+
self.optimG.load_state_dict(data_opt['optimG'])
|
241 |
+
# self.scheG.load_state_dict(data_opt['scheG'])
|
242 |
+
if not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
243 |
+
self.optimD.load_state_dict(data_opt['optimD'])
|
244 |
+
# self.scheD.load_state_dict(data_opt['scheD'])
|
245 |
+
self.epoch = data_opt['epoch']
|
246 |
+
self.iteration = data_opt['iteration']
|
247 |
+
else:
|
248 |
+
gen_path = self.config['trainer'].get('gen_path', None)
|
249 |
+
dis_path = self.config['trainer'].get('dis_path', None)
|
250 |
+
opt_path = self.config['trainer'].get('opt_path', None)
|
251 |
+
if gen_path is not None:
|
252 |
+
if self.config['global_rank'] == 0:
|
253 |
+
print(f'Loading Gen-Net from {gen_path}...')
|
254 |
+
dataG = torch.load(gen_path, map_location=self.config['device'])
|
255 |
+
self.netG.load_state_dict(dataG)
|
256 |
+
|
257 |
+
if dis_path is not None and not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
258 |
+
if self.config['global_rank'] == 0:
|
259 |
+
print(f'Loading Dis-Net from {dis_path}...')
|
260 |
+
dataD = torch.load(dis_path, map_location=self.config['device'])
|
261 |
+
self.netD.load_state_dict(dataD)
|
262 |
+
if opt_path is not None:
|
263 |
+
data_opt = torch.load(opt_path, map_location=self.config['device'])
|
264 |
+
self.optimG.load_state_dict(data_opt['optimG'])
|
265 |
+
self.scheG.load_state_dict(data_opt['scheG'])
|
266 |
+
if not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
267 |
+
self.optimD.load_state_dict(data_opt['optimD'])
|
268 |
+
self.scheD.load_state_dict(data_opt['scheD'])
|
269 |
+
else:
|
270 |
+
if self.config['global_rank'] == 0:
|
271 |
+
print('Warning: There is no trained model found. '
|
272 |
+
'An initialized model will be used.')
|
273 |
+
|
274 |
+
def save(self, it):
|
275 |
+
"""Save parameters every eval_epoch"""
|
276 |
+
if self.config['global_rank'] == 0:
|
277 |
+
# configure path
|
278 |
+
gen_path = os.path.join(self.config['save_dir'],
|
279 |
+
f'gen_{it:06d}.pth')
|
280 |
+
dis_path = os.path.join(self.config['save_dir'],
|
281 |
+
f'dis_{it:06d}.pth')
|
282 |
+
opt_path = os.path.join(self.config['save_dir'],
|
283 |
+
f'opt_{it:06d}.pth')
|
284 |
+
print(f'\nsaving model to {gen_path} ...')
|
285 |
+
|
286 |
+
# remove .module for saving
|
287 |
+
if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
|
288 |
+
netG = self.netG.module
|
289 |
+
if not self.config['model']['no_dis']:
|
290 |
+
netD = self.netD.module
|
291 |
+
else:
|
292 |
+
netG = self.netG
|
293 |
+
if not self.config['model']['no_dis']:
|
294 |
+
netD = self.netD
|
295 |
+
|
296 |
+
# save checkpoints
|
297 |
+
torch.save(netG.state_dict(), gen_path)
|
298 |
+
if not self.config['model']['no_dis']:
|
299 |
+
torch.save(netD.state_dict(), dis_path)
|
300 |
+
torch.save(
|
301 |
+
{
|
302 |
+
'epoch': self.epoch,
|
303 |
+
'iteration': self.iteration,
|
304 |
+
'optimG': self.optimG.state_dict(),
|
305 |
+
'optimD': self.optimD.state_dict(),
|
306 |
+
'scheG': self.scheG.state_dict(),
|
307 |
+
'scheD': self.scheD.state_dict()
|
308 |
+
}, opt_path)
|
309 |
+
else:
|
310 |
+
torch.save(
|
311 |
+
{
|
312 |
+
'epoch': self.epoch,
|
313 |
+
'iteration': self.iteration,
|
314 |
+
'optimG': self.optimG.state_dict(),
|
315 |
+
'scheG': self.scheG.state_dict()
|
316 |
+
}, opt_path)
|
317 |
+
|
318 |
+
latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
|
319 |
+
os.system(f"echo {it:06d} > {latest_path}")
|
320 |
+
|
321 |
+
def train(self):
|
322 |
+
"""training entry"""
|
323 |
+
pbar = range(int(self.train_args['iterations']))
|
324 |
+
if self.config['global_rank'] == 0:
|
325 |
+
pbar = tqdm(pbar,
|
326 |
+
initial=self.iteration,
|
327 |
+
dynamic_ncols=True,
|
328 |
+
smoothing=0.01)
|
329 |
+
|
330 |
+
os.makedirs('logs', exist_ok=True)
|
331 |
+
|
332 |
+
logging.basicConfig(
|
333 |
+
level=logging.INFO,
|
334 |
+
format="%(asctime)s %(filename)s[line:%(lineno)d]"
|
335 |
+
"%(levelname)s %(message)s",
|
336 |
+
datefmt="%a, %d %b %Y %H:%M:%S",
|
337 |
+
filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log",
|
338 |
+
filemode='w')
|
339 |
+
|
340 |
+
while True:
|
341 |
+
self.epoch += 1
|
342 |
+
self.prefetcher.reset()
|
343 |
+
if self.config['distributed']:
|
344 |
+
self.train_sampler.set_epoch(self.epoch)
|
345 |
+
self._train_epoch(pbar)
|
346 |
+
if self.iteration > self.train_args['iterations']:
|
347 |
+
break
|
348 |
+
print('\nEnd training....')
|
349 |
+
|
350 |
+
def _train_epoch(self, pbar):
|
351 |
+
"""Process input and calculate loss every training epoch"""
|
352 |
+
device = self.config['device']
|
353 |
+
train_data = self.prefetcher.next()
|
354 |
+
while train_data is not None:
|
355 |
+
self.iteration += 1
|
356 |
+
frames, masks, flows_f, flows_b, _ = train_data
|
357 |
+
frames, masks = frames.to(device), masks.to(device).float()
|
358 |
+
l_t = self.num_local_frames
|
359 |
+
b, t, c, h, w = frames.size()
|
360 |
+
gt_local_frames = frames[:, :l_t, ...]
|
361 |
+
local_masks = masks[:, :l_t, ...].contiguous()
|
362 |
+
|
363 |
+
masked_frames = frames * (1 - masks)
|
364 |
+
masked_local_frames = masked_frames[:, :l_t, ...]
|
365 |
+
# get gt optical flow
|
366 |
+
if flows_f[0] == 'None' or flows_b[0] == 'None':
|
367 |
+
gt_flows_bi = self.fix_raft(gt_local_frames)
|
368 |
+
else:
|
369 |
+
gt_flows_bi = (flows_f.to(device), flows_b.to(device))
|
370 |
+
|
371 |
+
# ---- complete flow ----
|
372 |
+
pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
|
373 |
+
pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)
|
374 |
+
# pred_flows_bi = gt_flows_bi
|
375 |
+
|
376 |
+
# ---- image propagation ----
|
377 |
+
prop_imgs, updated_local_masks = self.netG.module.img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode)
|
378 |
+
updated_masks = masks.clone()
|
379 |
+
updated_masks[:, :l_t, ...] = updated_local_masks.view(b, l_t, 1, h, w)
|
380 |
+
updated_frames = masked_frames.clone()
|
381 |
+
prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks # merge
|
382 |
+
updated_frames[:, :l_t, ...] = prop_local_frames
|
383 |
+
|
384 |
+
# ---- feature propagation + Transformer ----
|
385 |
+
pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t)
|
386 |
+
pred_imgs = pred_imgs.view(b, -1, c, h, w)
|
387 |
+
|
388 |
+
# get the local frames
|
389 |
+
pred_local_frames = pred_imgs[:, :l_t, ...]
|
390 |
+
comp_local_frames = gt_local_frames * (1. - local_masks) + pred_local_frames * local_masks
|
391 |
+
comp_imgs = frames * (1. - masks) + pred_imgs * masks
|
392 |
+
|
393 |
+
gen_loss = 0
|
394 |
+
dis_loss = 0
|
395 |
+
# optimize net_g
|
396 |
+
if not self.config['model']['no_dis']:
|
397 |
+
for p in self.netD.parameters():
|
398 |
+
p.requires_grad = False
|
399 |
+
|
400 |
+
self.optimG.zero_grad()
|
401 |
+
|
402 |
+
# generator l1 loss
|
403 |
+
hole_loss = self.l1_loss(pred_imgs * masks, frames * masks)
|
404 |
+
hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
|
405 |
+
gen_loss += hole_loss
|
406 |
+
self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())
|
407 |
+
|
408 |
+
valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks))
|
409 |
+
valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight']
|
410 |
+
gen_loss += valid_loss
|
411 |
+
self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())
|
412 |
+
|
413 |
+
# perceptual loss
|
414 |
+
if self.config['losses']['perceptual_weight'] > 0:
|
415 |
+
perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight']
|
416 |
+
gen_loss += perc_loss
|
417 |
+
self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item())
|
418 |
+
|
419 |
+
# gan loss
|
420 |
+
if not self.config['model']['no_dis']:
|
421 |
+
# generator adversarial loss
|
422 |
+
gen_clip = self.netD(comp_imgs)
|
423 |
+
gan_loss = self.adversarial_loss(gen_clip, True, False)
|
424 |
+
gan_loss = gan_loss * self.config['losses']['adversarial_weight']
|
425 |
+
gen_loss += gan_loss
|
426 |
+
self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())
|
427 |
+
gen_loss.backward()
|
428 |
+
self.optimG.step()
|
429 |
+
|
430 |
+
if not self.config['model']['no_dis']:
|
431 |
+
# optimize net_d
|
432 |
+
for p in self.netD.parameters():
|
433 |
+
p.requires_grad = True
|
434 |
+
self.optimD.zero_grad()
|
435 |
+
|
436 |
+
# discriminator adversarial loss
|
437 |
+
real_clip = self.netD(frames)
|
438 |
+
fake_clip = self.netD(comp_imgs.detach())
|
439 |
+
dis_real_loss = self.adversarial_loss(real_clip, True, True)
|
440 |
+
dis_fake_loss = self.adversarial_loss(fake_clip, False, True)
|
441 |
+
dis_loss += (dis_real_loss + dis_fake_loss) / 2
|
442 |
+
self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
|
443 |
+
self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
|
444 |
+
dis_loss.backward()
|
445 |
+
self.optimD.step()
|
446 |
+
|
447 |
+
self.update_learning_rate()
|
448 |
+
|
449 |
+
# write image to tensorboard
|
450 |
+
if self.iteration % 200 == 0:
|
451 |
+
# img to cpu
|
452 |
+
t = 0
|
453 |
+
gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
454 |
+
masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
455 |
+
prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
456 |
+
pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
457 |
+
img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
|
458 |
+
prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
|
459 |
+
img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
|
460 |
+
if self.gen_writer is not None:
|
461 |
+
self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)
|
462 |
+
|
463 |
+
t = 5
|
464 |
+
if masked_local_frames.shape[1] > 5:
|
465 |
+
img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
|
466 |
+
prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
|
467 |
+
img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
|
468 |
+
if self.gen_writer is not None:
|
469 |
+
self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)
|
470 |
+
|
471 |
+
# flow to cpu
|
472 |
+
gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
|
473 |
+
masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu)
|
474 |
+
pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()
|
475 |
+
|
476 |
+
flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1)
|
477 |
+
if self.gen_writer is not None:
|
478 |
+
self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration)
|
479 |
+
|
480 |
+
# console logs
|
481 |
+
if self.config['global_rank'] == 0:
|
482 |
+
pbar.update(1)
|
483 |
+
if not self.config['model']['no_dis']:
|
484 |
+
pbar.set_description((f"d: {dis_loss.item():.3f}; "
|
485 |
+
f"hole: {hole_loss.item():.3f}; "
|
486 |
+
f"valid: {valid_loss.item():.3f}"))
|
487 |
+
else:
|
488 |
+
pbar.set_description((f"hole: {hole_loss.item():.3f}; "
|
489 |
+
f"valid: {valid_loss.item():.3f}"))
|
490 |
+
|
491 |
+
if self.iteration % self.train_args['log_freq'] == 0:
|
492 |
+
if not self.config['model']['no_dis']:
|
493 |
+
logging.info(f"[Iter {self.iteration}] "
|
494 |
+
f"d: {dis_loss.item():.4f}; "
|
495 |
+
f"hole: {hole_loss.item():.4f}; "
|
496 |
+
f"valid: {valid_loss.item():.4f}")
|
497 |
+
else:
|
498 |
+
logging.info(f"[Iter {self.iteration}] "
|
499 |
+
f"hole: {hole_loss.item():.4f}; "
|
500 |
+
f"valid: {valid_loss.item():.4f}")
|
501 |
+
|
502 |
+
# saving models
|
503 |
+
if self.iteration % self.train_args['save_freq'] == 0:
|
504 |
+
self.save(int(self.iteration))
|
505 |
+
|
506 |
+
if self.iteration > self.train_args['iterations']:
|
507 |
+
break
|
508 |
+
|
509 |
+
train_data = self.prefetcher.next()
|
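The generator objective in Trainer._train_epoch above is driven mainly by two masked L1 terms; this standalone sketch (not part of the repository) reproduces only that weighting on random tensors so the normalisation by the mask area is easy to see. The hole/valid weights are placeholders for the values read from config['losses'].

import torch
import torch.nn as nn

l1 = nn.L1Loss()
hole_weight, valid_weight = 1.0, 1.0          # placeholders for config['losses'][...]

pred = torch.rand(2, 5, 3, 64, 64)            # b, t, c, h, w
gt = torch.rand(2, 5, 3, 64, 64)
masks = (torch.rand(2, 5, 1, 64, 64) > 0.7).float()   # 1 inside the hole

# L1 over the full tensor, rescaled by the fraction of masked (resp. unmasked)
# pixels, mirroring hole_loss / torch.mean(masks) in Trainer._train_epoch.
hole_loss = l1(pred * masks, gt * masks) / torch.mean(masks) * hole_weight
valid_loss = l1(pred * (1 - masks), gt * (1 - masks)) / torch.mean(1 - masks) * valid_weight
print(hole_loss.item(), valid_loss.item())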
propainter/core/trainer_flow_w_edge.py
ADDED
@@ -0,0 +1,380 @@
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import importlib
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
13 |
+
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
|
16 |
+
from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
|
17 |
+
from core.dataset import TrainDataset
|
18 |
+
|
19 |
+
from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
|
20 |
+
|
21 |
+
# from skimage.feature import canny
|
22 |
+
from model.canny.canny_filter import Canny
|
23 |
+
from RAFT.utils.flow_viz_pt import flow_to_image
|
24 |
+
|
25 |
+
|
26 |
+
class Trainer:
|
27 |
+
def __init__(self, config):
|
28 |
+
self.config = config
|
29 |
+
self.epoch = 0
|
30 |
+
self.iteration = 0
|
31 |
+
self.num_local_frames = config['train_data_loader']['num_local_frames']
|
32 |
+
self.num_ref_frames = config['train_data_loader']['num_ref_frames']
|
33 |
+
|
34 |
+
# setup data set and data loader
|
35 |
+
self.train_dataset = TrainDataset(config['train_data_loader'])
|
36 |
+
|
37 |
+
self.train_sampler = None
|
38 |
+
self.train_args = config['trainer']
|
39 |
+
if config['distributed']:
|
40 |
+
self.train_sampler = DistributedSampler(
|
41 |
+
self.train_dataset,
|
42 |
+
num_replicas=config['world_size'],
|
43 |
+
rank=config['global_rank'])
|
44 |
+
|
45 |
+
dataloader_args = dict(
|
46 |
+
dataset=self.train_dataset,
|
47 |
+
batch_size=self.train_args['batch_size'] // config['world_size'],
|
48 |
+
shuffle=(self.train_sampler is None),
|
49 |
+
num_workers=self.train_args['num_workers'],
|
50 |
+
sampler=self.train_sampler,
|
51 |
+
drop_last=True)
|
52 |
+
|
53 |
+
self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
|
54 |
+
self.prefetcher = CPUPrefetcher(self.train_loader)
|
55 |
+
|
56 |
+
# set raft
|
57 |
+
self.fix_raft = RAFT_bi(device = self.config['device'])
|
58 |
+
self.flow_loss = FlowLoss()
|
59 |
+
self.edge_loss = EdgeLoss()
|
60 |
+
self.canny = Canny(sigma=(2,2), low_threshold=0.1, high_threshold=0.2)
|
61 |
+
|
62 |
+
# setup models including generator and discriminator
|
63 |
+
net = importlib.import_module('model.' + config['model']['net'])
|
64 |
+
self.netG = net.RecurrentFlowCompleteNet()
|
65 |
+
# print(self.netG)
|
66 |
+
self.netG = self.netG.to(self.config['device'])
|
67 |
+
|
68 |
+
# setup optimizers and schedulers
|
69 |
+
self.setup_optimizers()
|
70 |
+
self.setup_schedulers()
|
71 |
+
self.load()
|
72 |
+
|
73 |
+
if config['distributed']:
|
74 |
+
self.netG = DDP(self.netG,
|
75 |
+
device_ids=[self.config['local_rank']],
|
76 |
+
output_device=self.config['local_rank'],
|
77 |
+
broadcast_buffers=True,
|
78 |
+
find_unused_parameters=True)
|
79 |
+
|
80 |
+
# set summary writer
|
81 |
+
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    def setup_optimizers(self):
        """Set up optimizers."""
        backbone_params = []
        for name, param in self.netG.named_parameters():
            if param.requires_grad:
                backbone_params.append(param)
            else:
                print(f'Params {name} will not be optimized.')

        optim_params = [
            {
                'params': backbone_params,
                'lr': self.config['trainer']['lr']
            },
        ]

        self.optimG = torch.optim.Adam(optim_params,
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))

    def setup_schedulers(self):
        """Set up schedulers."""
        scheduler_opt = self.config['trainer']['scheduler']
        scheduler_type = scheduler_opt.pop('type')

        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            self.scheG = MultiStepRestartLR(
                self.optimG,
                milestones=scheduler_opt['milestones'],
                gamma=scheduler_opt['gamma'])
        elif scheduler_type == 'CosineAnnealingRestartLR':
            self.scheG = CosineAnnealingRestartLR(
                self.optimG,
                periods=scheduler_opt['periods'],
                restart_weights=scheduler_opt['restart_weights'])
        else:
            raise NotImplementedError(
                f'Scheduler {scheduler_type} is not implemented yet.')

    def update_learning_rate(self):
        """Update learning rate."""
        self.scheG.step()

    def get_lr(self):
        """Get current learning rate."""
        return self.optimG.param_groups[0]['lr']

    def add_summary(self, writer, name, val):
        """Add tensorboard summary."""
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        n = self.train_args['log_freq']
        if writer is not None and self.iteration % n == 0:
            writer.add_scalar(name, self.summary[name] / n, self.iteration)
            self.summary[name] = 0

    def load(self):
        """Load netG."""
        # get the latest checkpoint
        model_path = self.config['save_dir']
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
                                'r').read().splitlines()[-1]
        else:
            ckpts = [
                os.path.basename(i).split('.pth')[0]
                for i in glob.glob(os.path.join(model_path, '*.pth'))
            ]
            ckpts.sort()
            latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None

        if latest_epoch is not None:
            gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth')
            opt_path = os.path.join(model_path, f'opt_{int(latest_epoch):06d}.pth')

            if self.config['global_rank'] == 0:
                print(f'Loading model from {gen_path}...')
            dataG = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(dataG)

            data_opt = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data_opt['optimG'])
            self.scheG.load_state_dict(data_opt['scheG'])

            self.epoch = data_opt['epoch']
            self.iteration = data_opt['iteration']

        else:
            if self.config['global_rank'] == 0:
                print('Warning: There is no trained model found. '
                      'An initialized model will be used.')

    def save(self, it):
        """Save parameters every eval_epoch."""
        if self.config['global_rank'] == 0:
            # configure path
            gen_path = os.path.join(self.config['save_dir'],
                                    f'gen_{it:06d}.pth')
            opt_path = os.path.join(self.config['save_dir'],
                                    f'opt_{it:06d}.pth')
            print(f'\nsaving model to {gen_path} ...')

            # remove .module for saving
            if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
                netG = self.netG.module
            else:
                netG = self.netG

            # save checkpoints
            torch.save(netG.state_dict(), gen_path)
            torch.save(
                {
                    'epoch': self.epoch,
                    'iteration': self.iteration,
                    'optimG': self.optimG.state_dict(),
                    'scheG': self.scheG.state_dict()
                }, opt_path)

            latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
            os.system(f"echo {it:06d} > {latest_path}")

    def train(self):
        """Training entry."""
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar,
                        initial=self.iteration,
                        dynamic_ncols=True,
                        smoothing=0.01)

        os.makedirs('logs', exist_ok=True)

        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s %(filename)s[line:%(lineno)d]"
            "%(levelname)s %(message)s",
            datefmt="%a, %d %b %Y %H:%M:%S",
            filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log",
            filemode='w')

        while True:
            self.epoch += 1
            self.prefetcher.reset()
            if self.config['distributed']:
                self.train_sampler.set_epoch(self.epoch)
            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    # def get_edges(self, flows):  # fgvc
    #     # (b, t, 2, H, W)
    #     b, t, _, h, w = flows.shape
    #     flows = flows.view(-1, 2, h, w)
    #     flows_list = flows.permute(0, 2, 3, 1).cpu().numpy()
    #     edges = []
    #     for f in list(flows_list):
    #         flows_gray = (f[:, :, 0] ** 2 + f[:, :, 1] ** 2) ** 0.5
    #         if flows_gray.max() < 1:
    #             flows_gray = flows_gray*0
    #         else:
    #             flows_gray = flows_gray / flows_gray.max()
    #
    #         edge = canny(flows_gray, sigma=2, low_threshold=0.1, high_threshold=0.2)  # fgvc
    #         edge = torch.from_numpy(edge).view(1, 1, h, w).float()
    #         edges.append(edge)
    #     edges = torch.stack(edges, dim=0).to(self.config['device'])
    #     edges = edges.view(b, t, 1, h, w)
    #     return edges

    def get_edges(self, flows):
        # (b, t, 2, H, W)
        b, t, _, h, w = flows.shape
        flows = flows.view(-1, 2, h, w)
        flows_gray = (flows[:, 0, None] ** 2 + flows[:, 1, None] ** 2) ** 0.5
        if flows_gray.max() < 1:
            flows_gray = flows_gray * 0
        else:
            flows_gray = flows_gray / flows_gray.max()

        magnitude, edges = self.canny(flows_gray.float())
        edges = edges.view(b, t, 1, h, w)
        return edges

    def _train_epoch(self, pbar):
        """Process input and calculate loss every training epoch."""
        device = self.config['device']
        train_data = self.prefetcher.next()
        while train_data is not None:
            self.iteration += 1
            frames, masks, flows_f, flows_b, _ = train_data
            frames, masks = frames.to(device), masks.to(device)
            masks = masks.float()

            l_t = self.num_local_frames
            b, t, c, h, w = frames.size()
            gt_local_frames = frames[:, :l_t, ...]
            local_masks = masks[:, :l_t, ...].contiguous()

            # get gt optical flow
            if flows_f[0] == 'None' or flows_b[0] == 'None':
                gt_flows_bi = self.fix_raft(gt_local_frames)
            else:
                gt_flows_bi = (flows_f.to(device), flows_b.to(device))

            # get gt edge
            gt_edges_forward = self.get_edges(gt_flows_bi[0])
            gt_edges_backward = self.get_edges(gt_flows_bi[1])
            gt_edges_bi = [gt_edges_forward, gt_edges_backward]

            # complete flow
            pred_flows_bi, pred_edges_bi = self.netG.module.forward_bidirect_flow(gt_flows_bi, local_masks)

            # optimize net_g
            self.optimG.zero_grad()

            # compute flow_loss
            flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames)
            flow_loss = flow_loss * self.config['losses']['flow_weight']
            warp_loss = warp_loss * 0.01
            self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item())
            self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item())

            # compute edge loss
            edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks)
            edge_loss = edge_loss * 1.0
            self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item())

            loss = flow_loss + warp_loss + edge_loss
            loss.backward()
            self.optimG.step()
            self.update_learning_rate()

            # write image to tensorboard
            # if self.iteration % 200 == 0:
            if self.iteration % 200 == 0 and self.gen_writer is not None:
                t = 5
                # forward to cpu
                gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
                masked_flows_forward_cpu = (gt_flows_forward_cpu[t] * (1 - local_masks[0][t].cpu())).to(gt_flows_forward_cpu)
                pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()

                flow_results = torch.cat([gt_flows_forward_cpu[t], masked_flows_forward_cpu, pred_flows_forward_cpu[t]], 1)
                self.gen_writer.add_image('img/flow-f:gt-pred', flow_results, self.iteration)

                # backward to cpu
                gt_flows_backward_cpu = flow_to_image(gt_flows_bi[1][0]).cpu()
                masked_flows_backward_cpu = (gt_flows_backward_cpu[t] * (1 - local_masks[0][t+1].cpu())).to(gt_flows_backward_cpu)
                pred_flows_backward_cpu = flow_to_image(pred_flows_bi[1][0]).cpu()

                flow_results = torch.cat([gt_flows_backward_cpu[t], masked_flows_backward_cpu, pred_flows_backward_cpu[t]], 1)
                self.gen_writer.add_image('img/flow-b:gt-pred', flow_results, self.iteration)

                # TODO: show edge
                # forward
                gt_edges_forward_cpu = gt_edges_bi[0][0].cpu()
                masked_edges_forward_cpu = (gt_edges_forward_cpu[t] * (1 - local_masks[0][t].cpu())).to(gt_edges_forward_cpu)
                pred_edges_forward_cpu = pred_edges_bi[0][0].cpu()

                edge_results = torch.cat([gt_edges_forward_cpu[t], masked_edges_forward_cpu, pred_edges_forward_cpu[t]], 1)
                self.gen_writer.add_image('img/edge-f:gt-pred', edge_results, self.iteration)
                # backward
                gt_edges_backward_cpu = gt_edges_bi[1][0].cpu()
                masked_edges_backward_cpu = (gt_edges_backward_cpu[t] * (1 - local_masks[0][t+1].cpu())).to(gt_edges_backward_cpu)
                pred_edges_backward_cpu = pred_edges_bi[1][0].cpu()

                edge_results = torch.cat([gt_edges_backward_cpu[t], masked_edges_backward_cpu, pred_edges_backward_cpu[t]], 1)
                self.gen_writer.add_image('img/edge-b:gt-pred', edge_results, self.iteration)

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((f"flow: {flow_loss.item():.3f}; "
                                      f"warp: {warp_loss.item():.3f}; "
                                      f"edge: {edge_loss.item():.3f}; "
                                      f"lr: {self.get_lr()}"))

                if self.iteration % self.train_args['log_freq'] == 0:
                    logging.info(f"[Iter {self.iteration}] "
                                 f"flow: {flow_loss.item():.4f}; "
                                 f"warp: {warp_loss.item():.4f}")

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration))

            if self.iteration > self.train_args['iterations']:
                break

            train_data = self.prefetcher.next()
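The trainer above pulls its optimizer and scheduler settings out of `config['trainer']`. As a reading aid only, here is a minimal sketch of the dictionary layout those methods expect; the key names mirror what `setup_optimizers()` and `setup_schedulers()` read, but the numeric values are placeholders, not the shipped training configuration.

# Hypothetical illustration: layout assumed by setup_optimizers()/setup_schedulers() above.
trainer_cfg = {
    'lr': 5e-5,                    # learning rate applied to all trainable netG params
    'beta1': 0.0,                  # Adam beta1
    'beta2': 0.99,                 # Adam beta2
    'scheduler': {
        'type': 'MultiStepLR',     # or 'CosineAnnealingRestartLR'
        'milestones': [300000],    # iterations at which the lr is multiplied by gamma
        'gamma': 0.1,
    },
}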
propainter/core/utils.py
ADDED
@@ -0,0 +1,371 @@
import os
import io
import cv2
import random
import numpy as np
from PIL import Image, ImageOps
import zipfile
import math

import torch
import matplotlib
import matplotlib.patches as patches
from matplotlib.path import Path
from matplotlib import pyplot as plt
from torchvision import transforms

# matplotlib.use('agg')

# ###########################################################################
# Directory IO
# ###########################################################################


def read_dirnames_under_root(root_dir):
    dirnames = [
        name for i, name in enumerate(sorted(os.listdir(root_dir)))
        if os.path.isdir(os.path.join(root_dir, name))
    ]
    print(f'Reading directories under {root_dir}, num: {len(dirnames)}')
    return dirnames


class TrainZipReader(object):
    file_dict = dict()

    def __init__(self):
        super(TrainZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        file_dict = TrainZipReader.file_dict
        if path in file_dict:
            return file_dict[path]
        else:
            file_handle = zipfile.ZipFile(path, 'r')
            file_dict[path] = file_handle
            return file_dict[path]

    @staticmethod
    def imread(path, idx):
        zfile = TrainZipReader.build_file_dict(path)
        filelist = zfile.namelist()
        filelist.sort()
        data = zfile.read(filelist[idx])
        im = Image.open(io.BytesIO(data))
        return im


class TestZipReader(object):
    file_dict = dict()

    def __init__(self):
        super(TestZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        file_dict = TestZipReader.file_dict
        if path in file_dict:
            return file_dict[path]
        else:
            file_handle = zipfile.ZipFile(path, 'r')
            file_dict[path] = file_handle
            return file_dict[path]

    @staticmethod
    def imread(path, idx):
        zfile = TestZipReader.build_file_dict(path)
        filelist = zfile.namelist()
        filelist.sort()
        data = zfile.read(filelist[idx])
        file_bytes = np.asarray(bytearray(data), dtype=np.uint8)
        im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        # im = Image.open(io.BytesIO(data))
        return im


# ###########################################################################
# Data augmentation
# ###########################################################################


def to_tensors():
    return transforms.Compose([Stack(), ToTorchFormatTensor()])


class GroupRandomHorizontalFlowFlip(object):
    """Randomly horizontally flips the given PIL.Images with a probability of 0.5."""
    def __call__(self, img_group, flowF_group, flowB_group):
        v = random.random()
        if v < 0.5:
            ret_img = [
                img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group
            ]
            ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group]
            ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group]
            return ret_img, ret_flowF, ret_flowB
        else:
            return img_group, flowF_group, flowB_group


class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Images with a probability of 0.5."""
    def __call__(self, img_group, is_flow=False):
        v = random.random()
        if v < 0.5:
            ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
            if is_flow:
                for i in range(0, len(ret), 2):
                    # invert flow pixel values when flipping
                    ret[i] = ImageOps.invert(ret[i])
            return ret
        else:
            return img_group


class Stack(object):
    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == '1':
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'
        if mode == 'L':
            return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
        elif mode == 'RGB':
            if self.roll:
                return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
                                axis=2)
            else:
                return np.stack(img_group, axis=2)
        else:
            raise NotImplementedError(f"Image mode {mode}")


class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. """
    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy img: [L, C, H, W]
            img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(
                pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        img = img.float().div(255) if self.div else img.float()
        return img


# ###########################################################################
# Create masks with random shape
# ###########################################################################


def create_random_shape_with_random_motion(video_length,
                                           imageHeight=240,
                                           imageWidth=432):
    # get a random shape
    height = random.randint(imageHeight // 3, imageHeight - 1)
    width = random.randint(imageWidth // 3, imageWidth - 1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8) / 10

    region = get_random_shape(edge_num=edge_num,
                              ratio=ratio,
                              height=height,
                              width=width)
    region_width, region_height = region.size
    # get random position
    x, y = random.randint(0, imageHeight - region_height), random.randint(
        0, imageWidth - region_width)
    velocity = get_random_velocity(max_speed=3)
    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
    m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
    masks = [m.convert('L')]
    # return fixed masks
    if random.uniform(0, 1) > 0.5:
        return masks * video_length
    # return moving masks
    for _ in range(video_length - 1):
        x, y, velocity = random_move_control_points(x,
                                                    y,
                                                    imageHeight,
                                                    imageWidth,
                                                    velocity,
                                                    region.size,
                                                    maxLineAcceleration=(3, 0.5),
                                                    maxInitSpeed=3)
        m = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
        masks.append(m.convert('L'))
    return masks


def create_random_shape_with_random_motion_zoom_rotation(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432):
    # get a random shape
    assert zoomin < 1, "Zoom-in parameter must be smaller than 1"
    assert zoomout > 1, "Zoom-out parameter must be larger than 1"
    assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximum value!"
    height = random.randint(imageHeight//3, imageHeight-1)
    width = random.randint(imageWidth//3, imageWidth-1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8)/10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size
    # get random position
    x, y = random.randint(
        0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
    velocity = get_random_velocity(max_speed=3)
    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
    m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
    masks = [m.convert('L')]
    # return fixed masks
    if random.uniform(0, 1) > 0.5:
        return masks*video_length  # -> directly copy all the base masks
    # return moving masks
    for _ in range(video_length-1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        m = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        ### added by kaidong, to simulate zoom-in, zoom-out and rotation
        extra_transform = random.uniform(0, 1)
        # zoom in and zoom out
        if extra_transform > 0.75:
            resize_coefficient = random.uniform(zoomin, zoomout)
            region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST)
            m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
            region_width, region_height = region.size
        # rotation
        elif extra_transform > 0.5:
            m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
            m = m.rotate(random.randint(rotmin, rotmax))
            # region_width, region_height = region.size
        ### end
        else:
            m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
        masks.append(m.convert('L'))
    return masks


def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
    '''
    There is the initial point and 3 points per cubic bezier curve.
    Thus, the curve will only pass through n points, which will be the sharp edges.
    The other 2 modify the shape of the bezier curve.
    edge_num, Number of possibly sharp edges
    points_num, number of points in the Path
    ratio, (0, 1) magnitude of the perturbation from the unit circle
    '''
    points_num = edge_num*3 + 1
    angles = np.linspace(0, 2*np.pi, points_num)
    codes = np.full(points_num, Path.CURVE4)
    codes[0] = Path.MOVETO
    # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
    verts[-1, :] = verts[0, :]
    path = Path(verts, codes)
    # draw paths into images
    fig = plt.figure()
    ax = fig.add_subplot(111)
    patch = patches.PathPatch(path, facecolor='black', lw=2)
    ax.add_patch(patch)
    ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.axis('off')  # removes the axis to leave only the shape
    fig.canvas.draw()
    # convert plt images into numpy images
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
    plt.close(fig)
    # postprocess
    data = cv2.resize(data, (width, height))[:, :, 0]
    data = (1 - np.array(data > 0).astype(np.uint8))*255
    corrdinates = np.where(data > 0)
    xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
        corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
    region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
    return region


def random_accelerate(velocity, maxAcceleration, dist='uniform'):
    speed, angle = velocity
    d_speed, d_angle = maxAcceleration
    if dist == 'uniform':
        speed += np.random.uniform(-d_speed, d_speed)
        angle += np.random.uniform(-d_angle, d_angle)
    elif dist == 'guassian':
        speed += np.random.normal(0, d_speed / 2)
        angle += np.random.normal(0, d_angle / 2)
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    return (speed, angle)


def get_random_velocity(max_speed=3, dist='uniform'):
    if dist == 'uniform':
        speed = np.random.uniform(max_speed)
    elif dist == 'guassian':
        speed = np.abs(np.random.normal(0, max_speed / 2))
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    angle = np.random.uniform(0, 2 * np.pi)
    return (speed, angle)


def random_move_control_points(X,
                               Y,
                               imageHeight,
                               imageWidth,
                               lineVelocity,
                               region_size,
                               maxLineAcceleration=(3, 0.5),
                               maxInitSpeed=3):
    region_width, region_height = region_size
    speed, angle = lineVelocity
    X += int(speed * np.cos(angle))
    Y += int(speed * np.sin(angle))
    lineVelocity = random_accelerate(lineVelocity,
                                     maxLineAcceleration,
                                     dist='guassian')
    if ((X > imageHeight - region_height) or (X < 0)
            or (Y > imageWidth - region_width) or (Y < 0)):
        lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
    new_X = np.clip(X, 0, imageHeight - region_height)
    new_Y = np.clip(Y, 0, imageWidth - region_width)
    return new_X, new_Y, lineVelocity


if __name__ == '__main__':

    trials = 10
    for _ in range(trials):
        video_length = 10
        # The returned masks are either stationary (50%) or moving (50%)
        masks = create_random_shape_with_random_motion(video_length,
                                                       imageHeight=240,
                                                       imageWidth=432)

        for m in masks:
            cv2.imshow('mask', np.array(m))
            cv2.waitKey(500)
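For orientation, `to_tensors()` above converts a list of L PIL frames into an (L, C, H, W) float tensor in [0, 1]; this is exactly how `propainter/inference.py` below prepares frames and masks. A minimal sketch, assuming it is run in the context of this module (the shapes follow from `Stack` and `ToTorchFormatTensor` above, the frame size is arbitrary):

# Illustrative only: exercises to_tensors() and the random-mask generator defined above.
length = 8
masks = create_random_shape_with_random_motion(length, imageHeight=240, imageWidth=432)
mask_tensor = to_tensors()(masks)            # (L, 1, H, W), values in [0, 1]
frames = [Image.new('RGB', (432, 240)) for _ in range(length)]
frame_tensor = to_tensors()(frames) * 2 - 1  # (L, 3, H, W), rescaled to [-1, 1] as in inference.py
print(mask_tensor.shape, frame_tensor.shape)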
propainter/inference.py
ADDED
@@ -0,0 +1,520 @@
# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
import scipy.ndimage
from PIL import Image
from tqdm import tqdm
import torch
import torchvision
import gc

try:
    from model.modules.flow_comp_raft import RAFT_bi
    from model.recurrent_flow_completion import RecurrentFlowCompleteNet
    from model.propainter import InpaintGenerator
    from utils.download_util import load_file_from_url
    from core.utils import to_tensors
    from model.misc import get_device
except:
    from propainter.model.modules.flow_comp_raft import RAFT_bi
    from propainter.model.recurrent_flow_completion import RecurrentFlowCompleteNet
    from propainter.model.propainter import InpaintGenerator
    from propainter.utils.download_util import load_file_from_url
    from propainter.core.utils import to_tensors
    from propainter.model.misc import get_device

import warnings
warnings.filterwarnings("ignore")

pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
MaxSideThresh = 960


# resize frames
def resize_frames(frames, size=None):
    if size is not None:
        out_size = size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        frames = [f.resize(process_size) for f in frames]
    else:
        out_size = frames[0].size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        if not out_size == process_size:
            frames = [f.resize(process_size) for f in frames]

    return frames, process_size, out_size

# read frames from video
def read_frame_from_videos(frame_root, video_length):
    if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):  # input video path
        video_name = os.path.basename(frame_root)[:-4]
        vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec', end_pts=video_length)  # RGB
        frames = list(vframes.numpy())
        frames = [Image.fromarray(f) for f in frames]
        fps = info['video_fps']
        nframes = len(frames)
    else:
        video_name = os.path.basename(frame_root)
        frames = []
        fr_lst = sorted(os.listdir(frame_root))
        for fr in fr_lst:
            frame = cv2.imread(os.path.join(frame_root, fr))
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(frame)
        fps = None
        nframes = len(frames)
    size = frames[0].size

    return frames, fps, size, video_name, nframes

def binary_mask(mask, th=0.1):
    mask[mask > th] = 1
    mask[mask <= th] = 0
    return mask

# read frame-wise masks
def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates=5):
    masks_img = []
    masks_dilated = []
    flow_masks = []

    if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')):  # input single img path
        masks_img = [Image.open(mpath)]
    elif mpath.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):  # input video path
        cap = cv2.VideoCapture(mpath)
        if not cap.isOpened():
            print("Error: Could not open video.")
            exit()
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx >= frames_len:
                break
            masks_img.append(Image.fromarray(frame))
            idx += 1
        cap.release()
    else:
        mnames = sorted(os.listdir(mpath))
        for mp in mnames:
            masks_img.append(Image.open(os.path.join(mpath, mp)))
            # print(mp)

    for mask_img in masks_img:
        if size is not None:
            mask_img = mask_img.resize(size, Image.NEAREST)
        mask_img = np.array(mask_img.convert('L'))

        # Dilate 8 pixels so that all known pixels are trustworthy
        if flow_mask_dilates > 0:
            flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
        else:
            flow_mask_img = binary_mask(mask_img).astype(np.uint8)
        # Close the small holes inside the foreground objects
        # flow_mask_img = cv2.morphologyEx(flow_mask_img, cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(bool)
        # flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.uint8)
        flow_masks.append(Image.fromarray(flow_mask_img * 255))

        if mask_dilates > 0:
            mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
        else:
            mask_img = binary_mask(mask_img).astype(np.uint8)
        masks_dilated.append(Image.fromarray(mask_img * 255))

    if len(masks_img) == 1:
        flow_masks = flow_masks * frames_len
        masks_dilated = masks_dilated * frames_len

    return flow_masks, masks_dilated

def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
    ref_index = []
    if ref_num == -1:
        for i in range(0, length, ref_stride):
            if i not in neighbor_ids:
                ref_index.append(i)
    else:
        start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
        end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
        for i in range(start_idx, end_idx, ref_stride):
            if i not in neighbor_ids:
                if len(ref_index) > ref_num:
                    break
                ref_index.append(i)
    return ref_index


class Propainter:
    def __init__(self, propainter_model_dir, device):
        self.device = device
        ##############################################
        # set up RAFT and flow completion model
        ##############################################
        ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'raft-things.pth'),
                                       model_dir=propainter_model_dir, progress=True, file_name=None)
        self.fix_raft = RAFT_bi(ckpt_path, device)

        ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'),
                                       model_dir=propainter_model_dir, progress=True, file_name=None)
        self.fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path)
        for p in self.fix_flow_complete.parameters():
            p.requires_grad = False
        self.fix_flow_complete.to(device)
        self.fix_flow_complete.eval()

        ##############################################
        # set up ProPainter model
        ##############################################
        ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'ProPainter.pth'),
                                       model_dir=propainter_model_dir, progress=True, file_name=None)
        self.model = InpaintGenerator(model_path=ckpt_path).to(device)
        self.model.eval()

    def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, height=-1, width=-1,
                mask_dilation=4, ref_stride=10, neighbor_length=10, subvideo_length=80,
                raft_iter=20, save_fps=24, save_frames=False, fp16=True):

        # Use fp16 precision during inference to reduce running memory cost
        use_half = True if fp16 else False
        if self.device == torch.device('cpu'):
            use_half = False

        ################ read input video ################
        frames, fps, size, video_name, nframes = read_frame_from_videos(video, video_length)
        frames = frames[:nframes]
        if not width == -1 and not height == -1:
            size = (width, height)

        longer_edge = max(size[0], size[1])
        if longer_edge > MaxSideThresh:
            scale = MaxSideThresh / longer_edge
            resize_ratio = resize_ratio * scale
        if not resize_ratio == 1.0:
            size = (int(resize_ratio * size[0]), int(resize_ratio * size[1]))

        frames, size, out_size = resize_frames(frames, size)
        fps = save_fps if fps is None else fps

        ################ read mask ################
        frames_len = len(frames)
        flow_masks, masks_dilated = read_mask(mask, frames_len, size,
                                              flow_mask_dilates=mask_dilation,
                                              mask_dilates=mask_dilation)
        flow_masks = flow_masks[:nframes]
        masks_dilated = masks_dilated[:nframes]
        w, h = size

        ################ adjust input ################
        frames_len = min(len(frames), len(masks_dilated))
        frames = frames[:frames_len]
        flow_masks = flow_masks[:frames_len]
        masks_dilated = masks_dilated[:frames_len]

        ori_frames_inp = [np.array(f).astype(np.uint8) for f in frames]
        frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
        flow_masks = to_tensors()(flow_masks).unsqueeze(0)
        masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
        frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to(self.device)

        ##############################################
        # ProPainter inference
        ##############################################
        video_length = frames.size(1)
        print(f'Priori generating: [{video_length} frames]...')
        with torch.no_grad():
            # ---- compute flow ----
            new_longer_edge = max(frames.size(-1), frames.size(-2))
            if new_longer_edge <= 640:
                short_clip_len = 12
            elif new_longer_edge <= 720:
                short_clip_len = 8
            elif new_longer_edge <= 1280:
                short_clip_len = 4
            else:
                short_clip_len = 2

            # use fp32 for RAFT
            if frames.size(1) > short_clip_len:
                gt_flows_f_list, gt_flows_b_list = [], []
                for f in range(0, video_length, short_clip_len):
                    end_f = min(video_length, f + short_clip_len)
                    if f == 0:
                        flows_f, flows_b = self.fix_raft(frames[:, f:end_f], iters=raft_iter)
                    else:
                        flows_f, flows_b = self.fix_raft(frames[:, f-1:end_f], iters=raft_iter)

                    gt_flows_f_list.append(flows_f)
                    gt_flows_b_list.append(flows_b)
                    torch.cuda.empty_cache()

                gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
                gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
                gt_flows_bi = (gt_flows_f, gt_flows_b)
            else:
                gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
                torch.cuda.empty_cache()
            torch.cuda.empty_cache()
            gc.collect()

            if use_half:
                frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
                gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
                self.fix_flow_complete = self.fix_flow_complete.half()
                self.model = self.model.half()

            # ---- complete flow ----
            flow_length = gt_flows_bi[0].size(1)
            if flow_length > subvideo_length:
                pred_flows_f, pred_flows_b = [], []
                pad_len = 5
                for f in range(0, flow_length, subvideo_length):
                    s_f = max(0, f - pad_len)
                    e_f = min(flow_length, f + subvideo_length + pad_len)
                    pad_len_s = max(0, f) - s_f
                    pad_len_e = e_f - min(flow_length, f + subvideo_length)
                    pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
                        (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                        flow_masks[:, s_f:e_f+1])
                    pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
                        (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                        pred_flows_bi_sub,
                        flow_masks[:, s_f:e_f+1])

                    pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
                    pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
                    torch.cuda.empty_cache()

                pred_flows_f = torch.cat(pred_flows_f, dim=1)
                pred_flows_b = torch.cat(pred_flows_b, dim=1)
                pred_flows_bi = (pred_flows_f, pred_flows_b)
            else:
                pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
                pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
                torch.cuda.empty_cache()
            torch.cuda.empty_cache()
            gc.collect()

            masks_dilated_ori = masks_dilated.clone()
            # ---- Pre-propagation ----
            subvideo_length_img_prop = min(100, subvideo_length)  # ensure a minimum of 100 frames for image propagation
            if len(frames[0]) > subvideo_length_img_prop:  # perform propagation only when the number of frames is larger than subvideo_length_img_prop
                sample_rate = len(frames[0]) // (subvideo_length_img_prop // 2)
                index_sample = list(range(0, len(frames[0]), sample_rate))
                sample_frames = torch.stack([frames[0][i].to(torch.float32) for i in index_sample]).unsqueeze(0)  # use fp32 for RAFT
                sample_masks_dilated = torch.stack([masks_dilated[0][i] for i in index_sample]).unsqueeze(0)
                sample_flow_masks = torch.stack([flow_masks[0][i] for i in index_sample]).unsqueeze(0)

                ## recompute flow for sampled frames
                # use fp32 for RAFT
                sample_video_length = sample_frames.size(1)
                if sample_frames.size(1) > short_clip_len:
                    gt_flows_f_list, gt_flows_b_list = [], []
                    for f in range(0, sample_video_length, short_clip_len):
                        end_f = min(sample_video_length, f + short_clip_len)
                        if f == 0:
                            flows_f, flows_b = self.fix_raft(sample_frames[:, f:end_f], iters=raft_iter)
                        else:
                            flows_f, flows_b = self.fix_raft(sample_frames[:, f-1:end_f], iters=raft_iter)

                        gt_flows_f_list.append(flows_f)
                        gt_flows_b_list.append(flows_b)
                        torch.cuda.empty_cache()

                    gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
                    gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
                    sample_gt_flows_bi = (gt_flows_f, gt_flows_b)
                else:
                    sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter)
                    torch.cuda.empty_cache()
                torch.cuda.empty_cache()
                gc.collect()

                if use_half:
                    sample_frames, sample_flow_masks, sample_masks_dilated = sample_frames.half(), sample_flow_masks.half(), sample_masks_dilated.half()
                    sample_gt_flows_bi = (sample_gt_flows_bi[0].half(), sample_gt_flows_bi[1].half())

                # ---- complete flow ----
                flow_length = sample_gt_flows_bi[0].size(1)
                if flow_length > subvideo_length:
                    pred_flows_f, pred_flows_b = [], []
                    pad_len = 5
                    for f in range(0, flow_length, subvideo_length):
                        s_f = max(0, f - pad_len)
                        e_f = min(flow_length, f + subvideo_length + pad_len)
                        pad_len_s = max(0, f) - s_f
                        pad_len_e = e_f - min(flow_length, f + subvideo_length)
                        pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
                            (sample_gt_flows_bi[0][:, s_f:e_f], sample_gt_flows_bi[1][:, s_f:e_f]),
                            sample_flow_masks[:, s_f:e_f+1])
                        pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
                            (sample_gt_flows_bi[0][:, s_f:e_f], sample_gt_flows_bi[1][:, s_f:e_f]),
                            pred_flows_bi_sub,
                            sample_flow_masks[:, s_f:e_f+1])

                        pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
                        pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
                        torch.cuda.empty_cache()

                    pred_flows_f = torch.cat(pred_flows_f, dim=1)
                    pred_flows_b = torch.cat(pred_flows_b, dim=1)
                    sample_pred_flows_bi = (pred_flows_f, pred_flows_b)
                else:
                    sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks)
                    sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks)
                    torch.cuda.empty_cache()
                torch.cuda.empty_cache()
                gc.collect()

                masked_frames = sample_frames * (1 - sample_masks_dilated)

                if sample_video_length > subvideo_length_img_prop:
                    updated_frames, updated_masks = [], []
                    pad_len = 10
                    for f in range(0, sample_video_length, subvideo_length_img_prop):
                        s_f = max(0, f - pad_len)
                        e_f = min(sample_video_length, f + subvideo_length_img_prop + pad_len)
                        pad_len_s = max(0, f) - s_f
                        pad_len_e = e_f - min(sample_video_length, f + subvideo_length_img_prop)

                        b, t, _, _, _ = sample_masks_dilated[:, s_f:e_f].size()
                        pred_flows_bi_sub = (sample_pred_flows_bi[0][:, s_f:e_f-1], sample_pred_flows_bi[1][:, s_f:e_f-1])
                        prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
                                                                                            pred_flows_bi_sub,
                                                                                            sample_masks_dilated[:, s_f:e_f],
                                                                                            'nearest')
                        updated_frames_sub = sample_frames[:, s_f:e_f] * (1 - sample_masks_dilated[:, s_f:e_f]) + \
                            prop_imgs_sub.view(b, t, 3, h, w) * sample_masks_dilated[:, s_f:e_f]
                        updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)

                        updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                        updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                        torch.cuda.empty_cache()

                    updated_frames = torch.cat(updated_frames, dim=1)
                    updated_masks = torch.cat(updated_masks, dim=1)
                else:
                    b, t, _, _, _ = sample_masks_dilated.size()
                    prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest')
                    updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated
                    updated_masks = updated_local_masks.view(b, t, 1, h, w)
                    torch.cuda.empty_cache()

                ## replace input frames/masks with updated frames/masks
                for i, index in enumerate(index_sample):
                    frames[0][index] = updated_frames[0][i]
                    masks_dilated[0][index] = updated_masks[0][i]

            # ---- frame-by-frame image propagation ----
            masked_frames = frames * (1 - masks_dilated)
            subvideo_length_img_prop = min(100, subvideo_length)  # ensure a minimum of 100 frames for image propagation
            if video_length > subvideo_length_img_prop:
                updated_frames, updated_masks = [], []
                pad_len = 10
                for f in range(0, video_length, subvideo_length_img_prop):
                    s_f = max(0, f - pad_len)
                    e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
                    pad_len_s = max(0, f) - s_f
                    pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)

                    b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
                    pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
                    prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
                                                                                        pred_flows_bi_sub,
                                                                                        masks_dilated[:, s_f:e_f],
                                                                                        'nearest')
                    updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
                        prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
                    updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)

                    updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                    updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                    torch.cuda.empty_cache()

                updated_frames = torch.cat(updated_frames, dim=1)
                updated_masks = torch.cat(updated_masks, dim=1)
            else:
                b, t, _, _, _ = masks_dilated.size()
                prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
                updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
                updated_masks = updated_local_masks.view(b, t, 1, h, w)
                torch.cuda.empty_cache()

        comp_frames = [None] * video_length

        neighbor_stride = neighbor_length // 2
        if video_length > subvideo_length:
            ref_num = subvideo_length // ref_stride
        else:
            ref_num = -1

        torch.cuda.empty_cache()
        # ---- feature propagation + transformer ----
        for f in tqdm(range(0, video_length, neighbor_stride)):
            neighbor_ids = [
                i for i in range(max(0, f - neighbor_stride),
                                 min(video_length, f + neighbor_stride + 1))
            ]
            ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
            selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
            selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
            selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
            selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])

            with torch.no_grad():
                # 1.0 indicates mask
                l_t = len(neighbor_ids)

                # pred_img = selected_imgs # results of image propagation
                pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
                pred_img = pred_img.view(-1, 3, h, w)

                ## compose with input frames
                pred_img = (pred_img + 1) / 2
                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                binary_masks = masks_dilated_ori[0, neighbor_ids, :, :, :].cpu().permute(
                    0, 2, 3, 1).numpy().astype(np.uint8)  # use original mask
                for i in range(len(neighbor_ids)):
                    idx = neighbor_ids[i]
                    img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
                        + ori_frames_inp[idx] * (1 - binary_masks[i])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5

                    comp_frames[idx] = comp_frames[idx].astype(np.uint8)

            torch.cuda.empty_cache()

        ## save composed video ##
        comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                 fps, (comp_frames[0].shape[1], comp_frames[0].shape[0]))
        for f in range(video_length):
            frame = comp_frames[f].astype(np.uint8)
            writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        writer.release()

        torch.cuda.empty_cache()

        return output_path


if __name__ == '__main__':

    device = get_device()
    propainter_model_dir = "weights/propainter"
    propainter = Propainter(propainter_model_dir, device=device)

    video = "examples/example1/video.mp4"
    mask = "examples/example1/mask.mp4"
    output = "results/priori.mp4"
    res = propainter.forward(video, mask, output)
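The transformer stage above runs on a sliding window of local neighbor frames plus sparse global references picked by `get_ref_index`. A small, self-contained illustration of that indexing, using `get_ref_index` from the file above (the numbers are arbitrary, chosen only to make the output easy to check by hand):

# Illustrative only: shows which frames one window attends to.
video_length = 30
neighbor_stride = 5            # = neighbor_length // 2 with the default neighbor_length=10
f = 10                         # window centre
neighbor_ids = list(range(max(0, f - neighbor_stride),
                          min(video_length, f + neighbor_stride + 1)))  # frames 5..15
ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride=10, ref_num=-1)
print(neighbor_ids)            # dense local frames fed to the model
print(ref_ids)                 # [0, 20] -> sparse references outside the window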
propainter/model/__init__.py
ADDED
@@ -0,0 +1 @@
(empty module file)
propainter/model/canny/canny_filter.py
ADDED
@@ -0,0 +1,256 @@
import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .gaussian import gaussian_blur2d
from .kernels import get_canny_nms_kernel, get_hysteresis_kernel
from .sobel import spatial_gradient

def rgb_to_grayscale(image, rgb_weights=None):
    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    if rgb_weights is None:
        # 8 bit images
        if image.dtype == torch.uint8:
            rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
        # floating point images
        elif image.dtype in (torch.float16, torch.float32, torch.float64):
            rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
        else:
            raise TypeError(f"Unknown data type: {image.dtype}")
    else:
        # is tensor that we make sure is in the same device/dtype
        rgb_weights = rgb_weights.to(image)

    # unpack the color image channels with RGB order
    r = image[..., 0:1, :, :]
    g = image[..., 1:2, :, :]
    b = image[..., 2:3, :, :]

    w_r, w_g, w_b = rgb_weights.unbind()
    return w_r * r + w_g * g + w_b * b


def canny(
    input: torch.Tensor,
    low_threshold: float = 0.1,
    high_threshold: float = 0.2,
    kernel_size: Tuple[int, int] = (5, 5),
    sigma: Tuple[float, float] = (1, 1),
    hysteresis: bool = True,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Find edges of the input image and filters them using the Canny algorithm.

    .. image:: _static/img/canny.png

    Args:
        input: input image tensor with shape :math:`(B,C,H,W)`.
        low_threshold: lower threshold for the hysteresis procedure.
        high_threshold: upper threshold for the hysteresis procedure.
        kernel_size: the size of the kernel for the gaussian blur.
        sigma: the standard deviation of the kernel for the gaussian blur.
        hysteresis: if True, applies the hysteresis edge tracking.
            Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
        eps: regularization number to avoid NaN during backprop.

    Returns:
        - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
        - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.

    .. note::
       See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
       canny.html>`__.

    Example:
        >>> input = torch.rand(5, 3, 4, 4)
        >>> magnitude, edges = canny(input)  # 5x3x4x4
        >>> magnitude.shape
        torch.Size([5, 1, 4, 4])
        >>> edges.shape
        torch.Size([5, 1, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if low_threshold > high_threshold:
        raise ValueError(
            "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format(
                low_threshold, high_threshold
            )
        )

    if low_threshold < 0 or low_threshold > 1:
        raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")

    if high_threshold < 0 or high_threshold > 1:
        raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # To Grayscale
    if input.shape[1] == 3:
        input = rgb_to_grayscale(input)

    # Gaussian filter
    blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma)

    # Compute the gradients
    gradients: torch.Tensor = spatial_gradient(blurred, normalized=False)

    # Unpack the edges
    gx: torch.Tensor = gradients[:, :, 0]
    gy: torch.Tensor = gradients[:, :, 1]

    # Compute gradient magnitude and angle
    magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
    angle: torch.Tensor = torch.atan2(gy, gx)

    # Radians to Degrees
    angle = 180.0 * angle / math.pi

    # Round angle to the nearest 45 degree
    angle = torch.round(angle / 45) * 45

    # Non-maximal suppression
    nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype)
    nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)

    # Get the indices for both directions
    positive_idx: torch.Tensor = (angle / 45) % 8
    positive_idx = positive_idx.long()

    negative_idx: torch.Tensor = ((angle / 45) + 4) % 8
    negative_idx = negative_idx.long()

    # Apply the non-maximum suppression to the different directions
    channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx)
    channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx)

    channel_select_filtered: torch.Tensor = torch.stack(
        [channel_select_filtered_positive, channel_select_filtered_negative], 1
    )

    is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0

    magnitude = magnitude * is_max

    # Threshold
    edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0)

    low: torch.Tensor = magnitude > low_threshold
    high: torch.Tensor = magnitude > high_threshold

    edges = low * 0.5 + high * 0.5
    edges = edges.to(dtype)

    # Hysteresis
    if hysteresis:
        edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
        hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype)

        while ((edges_old - edges).abs() != 0).any():
            weak: torch.Tensor = (edges == 0.5).float()
            strong: torch.Tensor = (edges == 1).float()

            hysteresis_magnitude: torch.Tensor = F.conv2d(
                edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
            )
            hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
            hysteresis_magnitude = hysteresis_magnitude * weak + strong

            edges_old = edges.clone()
            edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5

        edges = hysteresis_magnitude

    return magnitude, edges


class Canny(nn.Module):
    r"""Module that finds edges of the input image and filters them using the Canny algorithm.

    Args:
        input: input image tensor with shape :math:`(B,C,H,W)`.
        low_threshold: lower threshold for the hysteresis procedure.
        high_threshold: upper threshold for the hysteresis procedure.
        kernel_size: the size of the kernel for the gaussian blur.
|
186 |
+
sigma: the standard deviation of the kernel for the gaussian blur.
|
187 |
+
hysteresis: if True, applies the hysteresis edge tracking.
|
188 |
+
Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
|
189 |
+
eps: regularization number to avoid NaN during backprop.
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
- the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
|
193 |
+
- the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
|
194 |
+
|
195 |
+
Example:
|
196 |
+
>>> input = torch.rand(5, 3, 4, 4)
|
197 |
+
>>> magnitude, edges = Canny()(input) # 5x3x4x4
|
198 |
+
>>> magnitude.shape
|
199 |
+
torch.Size([5, 1, 4, 4])
|
200 |
+
>>> edges.shape
|
201 |
+
torch.Size([5, 1, 4, 4])
|
202 |
+
"""
|
203 |
+
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
low_threshold: float = 0.1,
|
207 |
+
high_threshold: float = 0.2,
|
208 |
+
kernel_size: Tuple[int, int] = (5, 5),
|
209 |
+
sigma: Tuple[float, float] = (1, 1),
|
210 |
+
hysteresis: bool = True,
|
211 |
+
eps: float = 1e-6,
|
212 |
+
) -> None:
|
213 |
+
super().__init__()
|
214 |
+
|
215 |
+
if low_threshold > high_threshold:
|
216 |
+
raise ValueError(
|
217 |
+
"Invalid input thresholds. low_threshold should be\
|
218 |
+
smaller than the high_threshold. Got: {}>{}".format(
|
219 |
+
low_threshold, high_threshold
|
220 |
+
)
|
221 |
+
)
|
222 |
+
|
223 |
+
if low_threshold < 0 or low_threshold > 1:
|
224 |
+
raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")
|
225 |
+
|
226 |
+
if high_threshold < 0 or high_threshold > 1:
|
227 |
+
raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")
|
228 |
+
|
229 |
+
# Gaussian blur parameters
|
230 |
+
self.kernel_size = kernel_size
|
231 |
+
self.sigma = sigma
|
232 |
+
|
233 |
+
# Double threshold
|
234 |
+
self.low_threshold = low_threshold
|
235 |
+
self.high_threshold = high_threshold
|
236 |
+
|
237 |
+
# Hysteresis
|
238 |
+
self.hysteresis = hysteresis
|
239 |
+
|
240 |
+
self.eps: float = eps
|
241 |
+
|
242 |
+
def __repr__(self) -> str:
|
243 |
+
return ''.join(
|
244 |
+
(
|
245 |
+
f'{type(self).__name__}(',
|
246 |
+
', '.join(
|
247 |
+
f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_')
|
248 |
+
),
|
249 |
+
')',
|
250 |
+
)
|
251 |
+
)
|
252 |
+
|
253 |
+
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
254 |
+
return canny(
|
255 |
+
input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps
|
256 |
+
)
|
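Usage note (not part of the diff): a minimal sketch of how the Canny module above could be exercised, assuming the package path propainter.model.canny matches this repository layout.

import torch
from propainter.model.canny.canny_filter import Canny, canny

frames = torch.rand(2, 3, 64, 64)  # (B, C, H, W) batch of RGB frames
magnitude, edges = canny(frames, low_threshold=0.1, high_threshold=0.2)  # both (2, 1, 64, 64)

detector = Canny(low_threshold=0.05, high_threshold=0.15, hysteresis=True)
_, edge_map = detector(frames)  # module form, equivalent to calling canny() with these settings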
propainter/model/canny/filter.py
ADDED
@@ -0,0 +1,288 @@
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from .kernels import normalize_kernel2d
|
7 |
+
|
8 |
+
|
9 |
+
def _compute_padding(kernel_size: List[int]) -> List[int]:
|
10 |
+
"""Compute padding tuple."""
|
11 |
+
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
|
12 |
+
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
|
13 |
+
if len(kernel_size) < 2:
|
14 |
+
raise AssertionError(kernel_size)
|
15 |
+
computed = [k - 1 for k in kernel_size]
|
16 |
+
|
17 |
+
# for even kernels we need to do asymmetric padding :(
|
18 |
+
out_padding = 2 * len(kernel_size) * [0]
|
19 |
+
|
20 |
+
for i in range(len(kernel_size)):
|
21 |
+
computed_tmp = computed[-(i + 1)]
|
22 |
+
|
23 |
+
pad_front = computed_tmp // 2
|
24 |
+
pad_rear = computed_tmp - pad_front
|
25 |
+
|
26 |
+
out_padding[2 * i + 0] = pad_front
|
27 |
+
out_padding[2 * i + 1] = pad_rear
|
28 |
+
|
29 |
+
return out_padding
|
30 |
+
|
31 |
+
|
32 |
+
def filter2d(
|
33 |
+
input: torch.Tensor,
|
34 |
+
kernel: torch.Tensor,
|
35 |
+
border_type: str = 'reflect',
|
36 |
+
normalized: bool = False,
|
37 |
+
padding: str = 'same',
|
38 |
+
) -> torch.Tensor:
|
39 |
+
r"""Convolve a tensor with a 2d kernel.
|
40 |
+
|
41 |
+
The function applies a given kernel to a tensor. The kernel is applied
|
42 |
+
independently at each depth channel of the tensor. Before applying the
|
43 |
+
kernel, the function applies padding according to the specified mode so
|
44 |
+
that the output remains in the same shape.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
input: the input tensor with shape of
|
48 |
+
:math:`(B, C, H, W)`.
|
49 |
+
kernel: the kernel to be convolved with the input
|
50 |
+
tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`.
|
51 |
+
border_type: the padding mode to be applied before convolving.
|
52 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
53 |
+
``'replicate'`` or ``'circular'``.
|
54 |
+
normalized: If True, kernel will be L1 normalized.
|
55 |
+
padding: This defines the type of padding.
|
56 |
+
2 modes available ``'same'`` or ``'valid'``.
|
57 |
+
|
58 |
+
Return:
|
59 |
+
torch.Tensor: the convolved tensor of same size and numbers of channels
|
60 |
+
as the input with shape :math:`(B, C, H, W)`.
|
61 |
+
|
62 |
+
Example:
|
63 |
+
>>> input = torch.tensor([[[
|
64 |
+
... [0., 0., 0., 0., 0.],
|
65 |
+
... [0., 0., 0., 0., 0.],
|
66 |
+
... [0., 0., 5., 0., 0.],
|
67 |
+
... [0., 0., 0., 0., 0.],
|
68 |
+
... [0., 0., 0., 0., 0.],]]])
|
69 |
+
>>> kernel = torch.ones(1, 3, 3)
|
70 |
+
>>> filter2d(input, kernel, padding='same')
|
71 |
+
tensor([[[[0., 0., 0., 0., 0.],
|
72 |
+
[0., 5., 5., 5., 0.],
|
73 |
+
[0., 5., 5., 5., 0.],
|
74 |
+
[0., 5., 5., 5., 0.],
|
75 |
+
[0., 0., 0., 0., 0.]]]])
|
76 |
+
"""
|
77 |
+
if not isinstance(input, torch.Tensor):
|
78 |
+
raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")
|
79 |
+
|
80 |
+
if not isinstance(kernel, torch.Tensor):
|
81 |
+
raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")
|
82 |
+
|
83 |
+
if not isinstance(border_type, str):
|
84 |
+
raise TypeError(f"Input border_type is not string. Got {type(border_type)}")
|
85 |
+
|
86 |
+
if border_type not in ['constant', 'reflect', 'replicate', 'circular']:
|
87 |
+
raise ValueError(
|
88 |
+
f"Invalid border type, we expect 'constant', \
|
89 |
+
'reflect', 'replicate', 'circular'. Got:{border_type}"
|
90 |
+
)
|
91 |
+
|
92 |
+
if not isinstance(padding, str):
|
93 |
+
raise TypeError(f"Input padding is not string. Got {type(padding)}")
|
94 |
+
|
95 |
+
if padding not in ['valid', 'same']:
|
96 |
+
raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}")
|
97 |
+
|
98 |
+
if not len(input.shape) == 4:
|
99 |
+
raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
|
100 |
+
|
101 |
+
if (not len(kernel.shape) == 3) and not ((kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])):
|
102 |
+
raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}")
|
103 |
+
|
104 |
+
# prepare kernel
|
105 |
+
b, c, h, w = input.shape
|
106 |
+
tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
|
107 |
+
|
108 |
+
if normalized:
|
109 |
+
tmp_kernel = normalize_kernel2d(tmp_kernel)
|
110 |
+
|
111 |
+
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
|
112 |
+
|
113 |
+
height, width = tmp_kernel.shape[-2:]
|
114 |
+
|
115 |
+
# pad the input tensor
|
116 |
+
if padding == 'same':
|
117 |
+
padding_shape: List[int] = _compute_padding([height, width])
|
118 |
+
input = F.pad(input, padding_shape, mode=border_type)
|
119 |
+
|
120 |
+
# kernel and input tensor reshape to align element-wise or batch-wise params
|
121 |
+
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
|
122 |
+
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
|
123 |
+
|
124 |
+
# convolve the tensor with the kernel.
|
125 |
+
output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
|
126 |
+
|
127 |
+
if padding == 'same':
|
128 |
+
out = output.view(b, c, h, w)
|
129 |
+
else:
|
130 |
+
out = output.view(b, c, h - height + 1, w - width + 1)
|
131 |
+
|
132 |
+
return out
|
133 |
+
|
134 |
+
|
135 |
+
def filter2d_separable(
|
136 |
+
input: torch.Tensor,
|
137 |
+
kernel_x: torch.Tensor,
|
138 |
+
kernel_y: torch.Tensor,
|
139 |
+
border_type: str = 'reflect',
|
140 |
+
normalized: bool = False,
|
141 |
+
padding: str = 'same',
|
142 |
+
) -> torch.Tensor:
|
143 |
+
r"""Convolve a tensor with two 1d kernels, in x and y directions.
|
144 |
+
|
145 |
+
The function applies a given kernel to a tensor. The kernel is applied
|
146 |
+
independently at each depth channel of the tensor. Before applying the
|
147 |
+
kernel, the function applies padding according to the specified mode so
|
148 |
+
that the output remains in the same shape.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
input: the input tensor with shape of
|
152 |
+
:math:`(B, C, H, W)`.
|
153 |
+
kernel_x: the kernel to be convolved with the input
|
154 |
+
tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`.
|
155 |
+
kernel_y: the kernel to be convolved with the input
|
156 |
+
tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`.
|
157 |
+
border_type: the padding mode to be applied before convolving.
|
158 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
159 |
+
``'replicate'`` or ``'circular'``.
|
160 |
+
normalized: If True, kernel will be L1 normalized.
|
161 |
+
padding: This defines the type of padding.
|
162 |
+
2 modes available ``'same'`` or ``'valid'``.
|
163 |
+
|
164 |
+
Return:
|
165 |
+
torch.Tensor: the convolved tensor of same size and numbers of channels
|
166 |
+
as the input with shape :math:`(B, C, H, W)`.
|
167 |
+
|
168 |
+
Example:
|
169 |
+
>>> input = torch.tensor([[[
|
170 |
+
... [0., 0., 0., 0., 0.],
|
171 |
+
... [0., 0., 0., 0., 0.],
|
172 |
+
... [0., 0., 5., 0., 0.],
|
173 |
+
... [0., 0., 0., 0., 0.],
|
174 |
+
... [0., 0., 0., 0., 0.],]]])
|
175 |
+
>>> kernel = torch.ones(1, 3)
|
176 |
+
|
177 |
+
>>> filter2d_separable(input, kernel, kernel, padding='same')
|
178 |
+
tensor([[[[0., 0., 0., 0., 0.],
|
179 |
+
[0., 5., 5., 5., 0.],
|
180 |
+
[0., 5., 5., 5., 0.],
|
181 |
+
[0., 5., 5., 5., 0.],
|
182 |
+
[0., 0., 0., 0., 0.]]]])
|
183 |
+
"""
|
184 |
+
out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding)
|
185 |
+
out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding)
|
186 |
+
return out
|
187 |
+
|
188 |
+
|
189 |
+
def filter3d(
|
190 |
+
input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False
|
191 |
+
) -> torch.Tensor:
|
192 |
+
r"""Convolve a tensor with a 3d kernel.
|
193 |
+
|
194 |
+
The function applies a given kernel to a tensor. The kernel is applied
|
195 |
+
independently at each depth channel of the tensor. Before applying the
|
196 |
+
kernel, the function applies padding according to the specified mode so
|
197 |
+
that the output remains in the same shape.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
input: the input tensor with shape of
|
201 |
+
:math:`(B, C, D, H, W)`.
|
202 |
+
kernel: the kernel to be convolved with the input
|
203 |
+
tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`.
|
204 |
+
border_type: the padding mode to be applied before convolving.
|
205 |
+
The expected modes are: ``'constant'``,
|
206 |
+
``'replicate'`` or ``'circular'``.
|
207 |
+
normalized: If True, kernel will be L1 normalized.
|
208 |
+
|
209 |
+
Return:
|
210 |
+
the convolved tensor of same size and numbers of channels
|
211 |
+
as the input with shape :math:`(B, C, D, H, W)`.
|
212 |
+
|
213 |
+
Example:
|
214 |
+
>>> input = torch.tensor([[[
|
215 |
+
... [[0., 0., 0., 0., 0.],
|
216 |
+
... [0., 0., 0., 0., 0.],
|
217 |
+
... [0., 0., 0., 0., 0.],
|
218 |
+
... [0., 0., 0., 0., 0.],
|
219 |
+
... [0., 0., 0., 0., 0.]],
|
220 |
+
... [[0., 0., 0., 0., 0.],
|
221 |
+
... [0., 0., 0., 0., 0.],
|
222 |
+
... [0., 0., 5., 0., 0.],
|
223 |
+
... [0., 0., 0., 0., 0.],
|
224 |
+
... [0., 0., 0., 0., 0.]],
|
225 |
+
... [[0., 0., 0., 0., 0.],
|
226 |
+
... [0., 0., 0., 0., 0.],
|
227 |
+
... [0., 0., 0., 0., 0.],
|
228 |
+
... [0., 0., 0., 0., 0.],
|
229 |
+
... [0., 0., 0., 0., 0.]]
|
230 |
+
... ]]])
|
231 |
+
>>> kernel = torch.ones(1, 3, 3, 3)
|
232 |
+
>>> filter3d(input, kernel)
|
233 |
+
tensor([[[[[0., 0., 0., 0., 0.],
|
234 |
+
[0., 5., 5., 5., 0.],
|
235 |
+
[0., 5., 5., 5., 0.],
|
236 |
+
[0., 5., 5., 5., 0.],
|
237 |
+
[0., 0., 0., 0., 0.]],
|
238 |
+
<BLANKLINE>
|
239 |
+
[[0., 0., 0., 0., 0.],
|
240 |
+
[0., 5., 5., 5., 0.],
|
241 |
+
[0., 5., 5., 5., 0.],
|
242 |
+
[0., 5., 5., 5., 0.],
|
243 |
+
[0., 0., 0., 0., 0.]],
|
244 |
+
<BLANKLINE>
|
245 |
+
[[0., 0., 0., 0., 0.],
|
246 |
+
[0., 5., 5., 5., 0.],
|
247 |
+
[0., 5., 5., 5., 0.],
|
248 |
+
[0., 5., 5., 5., 0.],
|
249 |
+
[0., 0., 0., 0., 0.]]]]])
|
250 |
+
"""
|
251 |
+
if not isinstance(input, torch.Tensor):
|
252 |
+
raise TypeError(f"Input border_type is not torch.Tensor. Got {type(input)}")
|
253 |
+
|
254 |
+
if not isinstance(kernel, torch.Tensor):
|
255 |
+
raise TypeError(f"Input border_type is not torch.Tensor. Got {type(kernel)}")
|
256 |
+
|
257 |
+
if not isinstance(border_type, str):
|
258 |
+
raise TypeError(f"Input border_type is not string. Got {type(kernel)}")
|
259 |
+
|
260 |
+
if not len(input.shape) == 5:
|
261 |
+
raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
|
262 |
+
|
263 |
+
if not len(kernel.shape) == 4 and kernel.shape[0] != 1:
|
264 |
+
raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}")
|
265 |
+
|
266 |
+
# prepare kernel
|
267 |
+
b, c, d, h, w = input.shape
|
268 |
+
tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
|
269 |
+
|
270 |
+
if normalized:
|
271 |
+
bk, dk, hk, wk = kernel.shape
|
272 |
+
tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel)
|
273 |
+
|
274 |
+
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)
|
275 |
+
|
276 |
+
# pad the input tensor
|
277 |
+
depth, height, width = tmp_kernel.shape[-3:]
|
278 |
+
padding_shape: List[int] = _compute_padding([depth, height, width])
|
279 |
+
input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)
|
280 |
+
|
281 |
+
# kernel and input tensor reshape to align element-wise or batch-wise params
|
282 |
+
tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
|
283 |
+
input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1))
|
284 |
+
|
285 |
+
# convolve the tensor with the kernel.
|
286 |
+
output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
|
287 |
+
|
288 |
+
return output.view(b, c, d, h, w)
|
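Usage note (not part of the diff): a small sketch of the generic filtering helpers defined above, assuming the same package layout.

import torch
from propainter.model.canny.filter import filter2d, filter2d_separable

img = torch.zeros(1, 1, 5, 5)
img[..., 2, 2] = 5.0                          # single bright pixel
box = torch.ones(1, 3, 3)                     # (1, kH, kW) box kernel
out = filter2d(img, box, border_type='reflect', padding='same')   # (1, 1, 5, 5)

row = torch.ones(1, 3)                        # 1d kernels for the separable path
out_sep = filter2d_separable(img, row, row)   # same result as the 3x3 box above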
propainter/model/canny/gaussian.py
ADDED
@@ -0,0 +1,116 @@
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .filter import filter2d, filter2d_separable
|
7 |
+
from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d
|
8 |
+
|
9 |
+
|
10 |
+
def gaussian_blur2d(
|
11 |
+
input: torch.Tensor,
|
12 |
+
kernel_size: Tuple[int, int],
|
13 |
+
sigma: Tuple[float, float],
|
14 |
+
border_type: str = 'reflect',
|
15 |
+
separable: bool = True,
|
16 |
+
) -> torch.Tensor:
|
17 |
+
r"""Create an operator that blurs a tensor using a Gaussian filter.
|
18 |
+
|
19 |
+
.. image:: _static/img/gaussian_blur2d.png
|
20 |
+
|
21 |
+
The operator smooths the given tensor with a gaussian kernel by convolving
|
22 |
+
it to each channel. It supports batched operation.
|
23 |
+
|
24 |
+
Arguments:
|
25 |
+
input: the input tensor with shape :math:`(B,C,H,W)`.
|
26 |
+
kernel_size: the size of the kernel.
|
27 |
+
sigma: the standard deviation of the kernel.
|
28 |
+
border_type: the padding mode to be applied before convolving.
|
29 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
30 |
+
``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
|
31 |
+
separable: run as composition of two 1d-convolutions.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
the blurred tensor with shape :math:`(B, C, H, W)`.
|
35 |
+
|
36 |
+
.. note::
|
37 |
+
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
|
38 |
+
gaussian_blur.html>`__.
|
39 |
+
|
40 |
+
Examples:
|
41 |
+
>>> input = torch.rand(2, 4, 5, 5)
|
42 |
+
>>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5))
|
43 |
+
>>> output.shape
|
44 |
+
torch.Size([2, 4, 5, 5])
|
45 |
+
"""
|
46 |
+
if separable:
|
47 |
+
kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1])
|
48 |
+
kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0])
|
49 |
+
out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type)
|
50 |
+
else:
|
51 |
+
kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma)
|
52 |
+
out = filter2d(input, kernel[None], border_type)
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
class GaussianBlur2d(nn.Module):
|
57 |
+
r"""Create an operator that blurs a tensor using a Gaussian filter.
|
58 |
+
|
59 |
+
The operator smooths the given tensor with a gaussian kernel by convolving
|
60 |
+
it to each channel. It supports batched operation.
|
61 |
+
|
62 |
+
Arguments:
|
63 |
+
kernel_size: the size of the kernel.
|
64 |
+
sigma: the standard deviation of the kernel.
|
65 |
+
border_type: the padding mode to be applied before convolving.
|
66 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
67 |
+
``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
|
68 |
+
separable: run as composition of two 1d-convolutions.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
the blurred tensor.
|
72 |
+
|
73 |
+
Shape:
|
74 |
+
- Input: :math:`(B, C, H, W)`
|
75 |
+
- Output: :math:`(B, C, H, W)`
|
76 |
+
|
77 |
+
Examples::
|
78 |
+
|
79 |
+
>>> input = torch.rand(2, 4, 5, 5)
|
80 |
+
>>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5))
|
81 |
+
>>> output = gauss(input) # 2x4x5x5
|
82 |
+
>>> output.shape
|
83 |
+
torch.Size([2, 4, 5, 5])
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
kernel_size: Tuple[int, int],
|
89 |
+
sigma: Tuple[float, float],
|
90 |
+
border_type: str = 'reflect',
|
91 |
+
separable: bool = True,
|
92 |
+
) -> None:
|
93 |
+
super().__init__()
|
94 |
+
self.kernel_size: Tuple[int, int] = kernel_size
|
95 |
+
self.sigma: Tuple[float, float] = sigma
|
96 |
+
self.border_type = border_type
|
97 |
+
self.separable = separable
|
98 |
+
|
99 |
+
def __repr__(self) -> str:
|
100 |
+
return (
|
101 |
+
self.__class__.__name__
|
102 |
+
+ '(kernel_size='
|
103 |
+
+ str(self.kernel_size)
|
104 |
+
+ ', '
|
105 |
+
+ 'sigma='
|
106 |
+
+ str(self.sigma)
|
107 |
+
+ ', '
|
108 |
+
+ 'border_type='
|
109 |
+
+ self.border_type
|
110 |
+
+ 'separable='
|
111 |
+
+ str(self.separable)
|
112 |
+
+ ')'
|
113 |
+
)
|
114 |
+
|
115 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
116 |
+
return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable)
|
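Usage note (not part of the diff): the blur entry points above can be used either functionally or as a module; a short sketch, assuming the same package layout.

import torch
from propainter.model.canny.gaussian import GaussianBlur2d, gaussian_blur2d

x = torch.rand(2, 4, 5, 5)
y = gaussian_blur2d(x, (3, 3), (1.5, 1.5))       # functional form, separable by default
blur = GaussianBlur2d((3, 3), (1.5, 1.5), separable=False)
y_mod = blur(x)                                  # module form using the full 2d kernel
assert y.shape == y_mod.shape == x.shape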
propainter/model/canny/kernels.py
ADDED
@@ -0,0 +1,690 @@
1 |
+
import math
|
2 |
+
from math import sqrt
|
3 |
+
from typing import List, Optional, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor:
|
9 |
+
r"""Normalize both derivative and smoothing kernel."""
|
10 |
+
if len(input.size()) < 2:
|
11 |
+
raise TypeError(f"input should be at least 2D tensor. Got {input.size()}")
|
12 |
+
norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1)
|
13 |
+
return input / (norm.unsqueeze(-1).unsqueeze(-1))
|
14 |
+
|
15 |
+
|
16 |
+
def gaussian(window_size: int, sigma: float) -> torch.Tensor:
|
17 |
+
device, dtype = None, None
|
18 |
+
if isinstance(sigma, torch.Tensor):
|
19 |
+
device, dtype = sigma.device, sigma.dtype
|
20 |
+
x = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2
|
21 |
+
if window_size % 2 == 0:
|
22 |
+
x = x + 0.5
|
23 |
+
|
24 |
+
gauss = torch.exp((-x.pow(2.0) / (2 * sigma**2)).float())
|
25 |
+
return gauss / gauss.sum()
|
26 |
+
|
27 |
+
|
28 |
+
def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor:
|
29 |
+
r"""Discrete Gaussian by interpolating the error function.
|
30 |
+
|
31 |
+
Adapted from:
|
32 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
33 |
+
"""
|
34 |
+
device = sigma.device if isinstance(sigma, torch.Tensor) else None
|
35 |
+
sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
|
36 |
+
x = torch.arange(window_size).float() - window_size // 2
|
37 |
+
t = 0.70710678 / torch.abs(sigma)
|
38 |
+
gauss = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())
|
39 |
+
gauss = gauss.clamp(min=0)
|
40 |
+
return gauss / gauss.sum()
|
41 |
+
|
42 |
+
|
43 |
+
def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor:
|
44 |
+
r"""Adapted from:
|
45 |
+
|
46 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
47 |
+
"""
|
48 |
+
if torch.abs(x) < 3.75:
|
49 |
+
y = (x / 3.75) * (x / 3.75)
|
50 |
+
return 1.0 + y * (
|
51 |
+
3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.360768e-1 + y * 0.45813e-2))))
|
52 |
+
)
|
53 |
+
ax = torch.abs(x)
|
54 |
+
y = 3.75 / ax
|
55 |
+
ans = 0.916281e-2 + y * (-0.2057706e-1 + y * (0.2635537e-1 + y * (-0.1647633e-1 + y * 0.392377e-2)))
|
56 |
+
coef = 0.39894228 + y * (0.1328592e-1 + y * (0.225319e-2 + y * (-0.157565e-2 + y * ans)))
|
57 |
+
return (torch.exp(ax) / torch.sqrt(ax)) * coef
|
58 |
+
|
59 |
+
|
60 |
+
def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor:
|
61 |
+
r"""adapted from:
|
62 |
+
|
63 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
64 |
+
"""
|
65 |
+
if torch.abs(x) < 3.75:
|
66 |
+
y = (x / 3.75) * (x / 3.75)
|
67 |
+
ans = 0.51498869 + y * (0.15084934 + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3)))
|
68 |
+
return torch.abs(x) * (0.5 + y * (0.87890594 + y * ans))
|
69 |
+
ax = torch.abs(x)
|
70 |
+
y = 3.75 / ax
|
71 |
+
ans = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1 - y * 0.420059e-2))
|
72 |
+
ans = 0.39894228 + y * (-0.3988024e-1 + y * (-0.362018e-2 + y * (0.163801e-2 + y * (-0.1031555e-1 + y * ans))))
|
73 |
+
ans = ans * torch.exp(ax) / torch.sqrt(ax)
|
74 |
+
return -ans if x < 0.0 else ans
|
75 |
+
|
76 |
+
|
77 |
+
def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor:
|
78 |
+
r"""adapted from:
|
79 |
+
|
80 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
81 |
+
"""
|
82 |
+
if n < 2:
|
83 |
+
raise ValueError("n must be greater than 1.")
|
84 |
+
if x == 0.0:
|
85 |
+
return x
|
86 |
+
device = x.device
|
87 |
+
tox = 2.0 / torch.abs(x)
|
88 |
+
ans = torch.tensor(0.0, device=device)
|
89 |
+
bip = torch.tensor(0.0, device=device)
|
90 |
+
bi = torch.tensor(1.0, device=device)
|
91 |
+
m = int(2 * (n + int(sqrt(40.0 * n))))
|
92 |
+
for j in range(m, 0, -1):
|
93 |
+
bim = bip + float(j) * tox * bi
|
94 |
+
bip = bi
|
95 |
+
bi = bim
|
96 |
+
if abs(bi) > 1.0e10:
|
97 |
+
ans = ans * 1.0e-10
|
98 |
+
bi = bi * 1.0e-10
|
99 |
+
bip = bip * 1.0e-10
|
100 |
+
if j == n:
|
101 |
+
ans = bip
|
102 |
+
ans = ans * _modified_bessel_0(x) / bi
|
103 |
+
return -ans if x < 0.0 and (n % 2) == 1 else ans
|
104 |
+
|
105 |
+
|
106 |
+
def gaussian_discrete(window_size, sigma) -> torch.Tensor:
|
107 |
+
r"""Discrete Gaussian kernel based on the modified Bessel functions.
|
108 |
+
|
109 |
+
Adapted from:
|
110 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
111 |
+
"""
|
112 |
+
device = sigma.device if isinstance(sigma, torch.Tensor) else None
|
113 |
+
sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
|
114 |
+
sigma2 = sigma * sigma
|
115 |
+
tail = int(window_size // 2)
|
116 |
+
out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1)
|
117 |
+
out_pos[0] = _modified_bessel_0(sigma2)
|
118 |
+
out_pos[1] = _modified_bessel_1(sigma2)
|
119 |
+
for k in range(2, len(out_pos)):
|
120 |
+
out_pos[k] = _modified_bessel_i(k, sigma2)
|
121 |
+
out = out_pos[:0:-1]
|
122 |
+
out.extend(out_pos)
|
123 |
+
out = torch.stack(out) * torch.exp(sigma2) # type: ignore
|
124 |
+
return out / out.sum() # type: ignore
|
125 |
+
|
126 |
+
|
127 |
+
def laplacian_1d(window_size) -> torch.Tensor:
|
128 |
+
r"""One could also use the Laplacian of Gaussian formula to design the filter."""
|
129 |
+
|
130 |
+
filter_1d = torch.ones(window_size)
|
131 |
+
filter_1d[window_size // 2] = 1 - window_size
|
132 |
+
laplacian_1d: torch.Tensor = filter_1d
|
133 |
+
return laplacian_1d
|
134 |
+
|
135 |
+
|
136 |
+
def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor:
|
137 |
+
r"""Utility function that returns a box filter."""
|
138 |
+
kx: float = float(kernel_size[0])
|
139 |
+
ky: float = float(kernel_size[1])
|
140 |
+
scale: torch.Tensor = torch.tensor(1.0) / torch.tensor([kx * ky])
|
141 |
+
tmp_kernel: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1])
|
142 |
+
return scale.to(tmp_kernel.dtype) * tmp_kernel
|
143 |
+
|
144 |
+
|
145 |
+
def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor:
|
146 |
+
r"""Create a binary kernel to extract the patches.
|
147 |
+
|
148 |
+
If the window size is HxW will create a (H*W)xHxW kernel.
|
149 |
+
"""
|
150 |
+
window_range: int = window_size[0] * window_size[1]
|
151 |
+
kernel: torch.Tensor = torch.zeros(window_range, window_range)
|
152 |
+
for i in range(window_range):
|
153 |
+
kernel[i, i] += 1.0
|
154 |
+
return kernel.view(window_range, 1, window_size[0], window_size[1])
|
155 |
+
|
156 |
+
|
157 |
+
def get_sobel_kernel_3x3() -> torch.Tensor:
|
158 |
+
"""Utility function that returns a sobel kernel of 3x3."""
|
159 |
+
return torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
|
160 |
+
|
161 |
+
|
162 |
+
def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor:
|
163 |
+
"""Utility function that returns a 2nd order sobel kernel of 5x5."""
|
164 |
+
return torch.tensor(
|
165 |
+
[
|
166 |
+
[-1.0, 0.0, 2.0, 0.0, -1.0],
|
167 |
+
[-4.0, 0.0, 8.0, 0.0, -4.0],
|
168 |
+
[-6.0, 0.0, 12.0, 0.0, -6.0],
|
169 |
+
[-4.0, 0.0, 8.0, 0.0, -4.0],
|
170 |
+
[-1.0, 0.0, 2.0, 0.0, -1.0],
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
|
175 |
+
def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor:
|
176 |
+
"""Utility function that returns a 2nd order sobel kernel of 5x5."""
|
177 |
+
return torch.tensor(
|
178 |
+
[
|
179 |
+
[-1.0, -2.0, 0.0, 2.0, 1.0],
|
180 |
+
[-2.0, -4.0, 0.0, 4.0, 2.0],
|
181 |
+
[0.0, 0.0, 0.0, 0.0, 0.0],
|
182 |
+
[2.0, 4.0, 0.0, -4.0, -2.0],
|
183 |
+
[1.0, 2.0, 0.0, -2.0, -1.0],
|
184 |
+
]
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
def get_diff_kernel_3x3() -> torch.Tensor:
|
189 |
+
"""Utility function that returns a first order derivative kernel of 3x3."""
|
190 |
+
return torch.tensor([[-0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [-0.0, 0.0, 0.0]])
|
191 |
+
|
192 |
+
|
193 |
+
def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
194 |
+
"""Utility function that returns a first order derivative kernel of 3x3x3."""
|
195 |
+
kernel: torch.Tensor = torch.tensor(
|
196 |
+
[
|
197 |
+
[
|
198 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
199 |
+
[[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]],
|
200 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
201 |
+
],
|
202 |
+
[
|
203 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
204 |
+
[[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]],
|
205 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
206 |
+
],
|
207 |
+
[
|
208 |
+
[[0.0, 0.0, 0.0], [0.0, -0.5, 0.0], [0.0, 0.0, 0.0]],
|
209 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
210 |
+
[[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.0]],
|
211 |
+
],
|
212 |
+
],
|
213 |
+
device=device,
|
214 |
+
dtype=dtype,
|
215 |
+
)
|
216 |
+
return kernel.unsqueeze(1)
|
217 |
+
|
218 |
+
|
219 |
+
def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
220 |
+
"""Utility function that returns a first order derivative kernel of 3x3x3."""
|
221 |
+
kernel: torch.Tensor = torch.tensor(
|
222 |
+
[
|
223 |
+
[
|
224 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
225 |
+
[[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]],
|
226 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
227 |
+
],
|
228 |
+
[
|
229 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
230 |
+
[[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]],
|
231 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
232 |
+
],
|
233 |
+
[
|
234 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
235 |
+
[[0.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]],
|
236 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
237 |
+
],
|
238 |
+
[
|
239 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
240 |
+
[[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]],
|
241 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
242 |
+
],
|
243 |
+
[
|
244 |
+
[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, -1.0, 0.0]],
|
245 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
246 |
+
[[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
|
247 |
+
],
|
248 |
+
[
|
249 |
+
[[0.0, 0.0, 0.0], [1.0, 0.0, -1.0], [0.0, 0.0, 0.0]],
|
250 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
251 |
+
[[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
|
252 |
+
],
|
253 |
+
],
|
254 |
+
device=device,
|
255 |
+
dtype=dtype,
|
256 |
+
)
|
257 |
+
return kernel.unsqueeze(1)
|
258 |
+
|
259 |
+
|
260 |
+
def get_sobel_kernel2d() -> torch.Tensor:
|
261 |
+
kernel_x: torch.Tensor = get_sobel_kernel_3x3()
|
262 |
+
kernel_y: torch.Tensor = kernel_x.transpose(0, 1)
|
263 |
+
return torch.stack([kernel_x, kernel_y])
|
264 |
+
|
265 |
+
|
266 |
+
def get_diff_kernel2d() -> torch.Tensor:
|
267 |
+
kernel_x: torch.Tensor = get_diff_kernel_3x3()
|
268 |
+
kernel_y: torch.Tensor = kernel_x.transpose(0, 1)
|
269 |
+
return torch.stack([kernel_x, kernel_y])
|
270 |
+
|
271 |
+
|
272 |
+
def get_sobel_kernel2d_2nd_order() -> torch.Tensor:
|
273 |
+
gxx: torch.Tensor = get_sobel_kernel_5x5_2nd_order()
|
274 |
+
gyy: torch.Tensor = gxx.transpose(0, 1)
|
275 |
+
gxy: torch.Tensor = _get_sobel_kernel_5x5_2nd_order_xy()
|
276 |
+
return torch.stack([gxx, gxy, gyy])
|
277 |
+
|
278 |
+
|
279 |
+
def get_diff_kernel2d_2nd_order() -> torch.Tensor:
|
280 |
+
gxx: torch.Tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]])
|
281 |
+
gyy: torch.Tensor = gxx.transpose(0, 1)
|
282 |
+
gxy: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, -1.0]])
|
283 |
+
return torch.stack([gxx, gxy, gyy])
|
284 |
+
|
285 |
+
|
286 |
+
def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor:
|
287 |
+
r"""Function that returns kernel for 1st or 2nd order image gradients, using one of the following operators:
|
288 |
+
|
289 |
+
sobel, diff.
|
290 |
+
"""
|
291 |
+
if mode not in ['sobel', 'diff']:
|
292 |
+
raise TypeError(
|
293 |
+
"mode should be either sobel\
|
294 |
+
or diff. Got {}".format(
|
295 |
+
mode
|
296 |
+
)
|
297 |
+
)
|
298 |
+
if order not in [1, 2]:
|
299 |
+
raise TypeError(
|
300 |
+
"order should be either 1 or 2\
|
301 |
+
Got {}".format(
|
302 |
+
order
|
303 |
+
)
|
304 |
+
)
|
305 |
+
if mode == 'sobel' and order == 1:
|
306 |
+
kernel: torch.Tensor = get_sobel_kernel2d()
|
307 |
+
elif mode == 'sobel' and order == 2:
|
308 |
+
kernel = get_sobel_kernel2d_2nd_order()
|
309 |
+
elif mode == 'diff' and order == 1:
|
310 |
+
kernel = get_diff_kernel2d()
|
311 |
+
elif mode == 'diff' and order == 2:
|
312 |
+
kernel = get_diff_kernel2d_2nd_order()
|
313 |
+
else:
|
314 |
+
raise NotImplementedError("")
|
315 |
+
return kernel
|
316 |
+
|
317 |
+
|
318 |
+
def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
319 |
+
r"""Function that returns kernel for 1st or 2nd order scale pyramid gradients, using one of the following
|
320 |
+
operators: sobel, diff."""
|
321 |
+
if mode not in ['sobel', 'diff']:
|
322 |
+
raise TypeError(
|
323 |
+
"mode should be either sobel\
|
324 |
+
or diff. Got {}".format(
|
325 |
+
mode
|
326 |
+
)
|
327 |
+
)
|
328 |
+
if order not in [1, 2]:
|
329 |
+
raise TypeError(
|
330 |
+
"order should be either 1 or 2\
|
331 |
+
Got {}".format(
|
332 |
+
order
|
333 |
+
)
|
334 |
+
)
|
335 |
+
if mode == 'sobel':
|
336 |
+
raise NotImplementedError("Sobel kernel for 3d gradient is not implemented yet")
|
337 |
+
if mode == 'diff' and order == 1:
|
338 |
+
kernel = get_diff_kernel3d(device, dtype)
|
339 |
+
elif mode == 'diff' and order == 2:
|
340 |
+
kernel = get_diff_kernel3d_2nd_order(device, dtype)
|
341 |
+
else:
|
342 |
+
raise NotImplementedError("")
|
343 |
+
return kernel
|
344 |
+
|
345 |
+
|
346 |
+
def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
|
347 |
+
r"""Function that returns Gaussian filter coefficients.
|
348 |
+
|
349 |
+
Args:
|
350 |
+
kernel_size: filter size. It should be odd and positive.
|
351 |
+
sigma: gaussian standard deviation.
|
352 |
+
force_even: overrides requirement for odd kernel size.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
1D tensor with gaussian filter coefficients.
|
356 |
+
|
357 |
+
Shape:
|
358 |
+
- Output: :math:`(\text{kernel_size})`
|
359 |
+
|
360 |
+
Examples:
|
361 |
+
|
362 |
+
>>> get_gaussian_kernel1d(3, 2.5)
|
363 |
+
tensor([0.3243, 0.3513, 0.3243])
|
364 |
+
|
365 |
+
>>> get_gaussian_kernel1d(5, 1.5)
|
366 |
+
tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
|
367 |
+
"""
|
368 |
+
if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
|
369 |
+
raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
|
370 |
+
window_1d: torch.Tensor = gaussian(kernel_size, sigma)
|
371 |
+
return window_1d
|
372 |
+
|
373 |
+
|
374 |
+
def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
|
375 |
+
r"""Function that returns Gaussian filter coefficients based on the modified Bessel functions. Adapted from:
|
376 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
kernel_size: filter size. It should be odd and positive.
|
380 |
+
sigma: gaussian standard deviation.
|
381 |
+
force_even: overrides requirement for odd kernel size.
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
1D tensor with gaussian filter coefficients.
|
385 |
+
|
386 |
+
Shape:
|
387 |
+
- Output: :math:`(\text{kernel_size})`
|
388 |
+
|
389 |
+
Examples:
|
390 |
+
|
391 |
+
>>> get_gaussian_discrete_kernel1d(3, 2.5)
|
392 |
+
tensor([0.3235, 0.3531, 0.3235])
|
393 |
+
|
394 |
+
>>> get_gaussian_discrete_kernel1d(5, 1.5)
|
395 |
+
tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096])
|
396 |
+
"""
|
397 |
+
if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
|
398 |
+
raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
|
399 |
+
window_1d = gaussian_discrete(kernel_size, sigma)
|
400 |
+
return window_1d
|
401 |
+
|
402 |
+
|
403 |
+
def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
|
404 |
+
r"""Function that returns Gaussian filter coefficients by interpolating the error function, adapted from:
|
405 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.
|
406 |
+
|
407 |
+
Args:
|
408 |
+
kernel_size: filter size. It should be odd and positive.
|
409 |
+
sigma: gaussian standard deviation.
|
410 |
+
force_even: overrides requirement for odd kernel size.
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
1D tensor with gaussian filter coefficients.
|
414 |
+
|
415 |
+
Shape:
|
416 |
+
- Output: :math:`(\text{kernel_size})`
|
417 |
+
|
418 |
+
Examples:
|
419 |
+
|
420 |
+
>>> get_gaussian_erf_kernel1d(3, 2.5)
|
421 |
+
tensor([0.3245, 0.3511, 0.3245])
|
422 |
+
|
423 |
+
>>> get_gaussian_erf_kernel1d(5, 1.5)
|
424 |
+
tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226])
|
425 |
+
"""
|
426 |
+
if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
|
427 |
+
raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
|
428 |
+
window_1d = gaussian_discrete_erf(kernel_size, sigma)
|
429 |
+
return window_1d
|
430 |
+
|
431 |
+
|
432 |
+
def get_gaussian_kernel2d(
|
433 |
+
kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False
|
434 |
+
) -> torch.Tensor:
|
435 |
+
r"""Function that returns Gaussian filter matrix coefficients.
|
436 |
+
|
437 |
+
Args:
|
438 |
+
kernel_size: filter sizes in the x and y direction.
|
439 |
+
Sizes should be odd and positive.
|
440 |
+
sigma: gaussian standard deviation in the x and y
|
441 |
+
direction.
|
442 |
+
force_even: overrides requirement for odd kernel size.
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
2D tensor with gaussian filter matrix coefficients.
|
446 |
+
|
447 |
+
Shape:
|
448 |
+
- Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
|
449 |
+
|
450 |
+
Examples:
|
451 |
+
>>> get_gaussian_kernel2d((3, 3), (1.5, 1.5))
|
452 |
+
tensor([[0.0947, 0.1183, 0.0947],
|
453 |
+
[0.1183, 0.1478, 0.1183],
|
454 |
+
[0.0947, 0.1183, 0.0947]])
|
455 |
+
>>> get_gaussian_kernel2d((3, 5), (1.5, 1.5))
|
456 |
+
tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
|
457 |
+
[0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
|
458 |
+
[0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
|
459 |
+
"""
|
460 |
+
if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
|
461 |
+
raise TypeError(f"kernel_size must be a tuple of length two. Got {kernel_size}")
|
462 |
+
if not isinstance(sigma, tuple) or len(sigma) != 2:
|
463 |
+
raise TypeError(f"sigma must be a tuple of length two. Got {sigma}")
|
464 |
+
ksize_x, ksize_y = kernel_size
|
465 |
+
sigma_x, sigma_y = sigma
|
466 |
+
kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even)
|
467 |
+
kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even)
|
468 |
+
kernel_2d: torch.Tensor = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
|
469 |
+
return kernel_2d
|
470 |
+
|
471 |
+
|
472 |
+
def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor:
|
473 |
+
r"""Function that returns the coefficients of a 1D Laplacian filter.
|
474 |
+
|
475 |
+
Args:
|
476 |
+
kernel_size: filter size. It should be odd and positive.
|
477 |
+
|
478 |
+
Returns:
|
479 |
+
1D tensor with laplacian filter coefficients.
|
480 |
+
|
481 |
+
Shape:
|
482 |
+
- Output: math:`(\text{kernel_size})`
|
483 |
+
|
484 |
+
Examples:
|
485 |
+
>>> get_laplacian_kernel1d(3)
|
486 |
+
tensor([ 1., -2., 1.])
|
487 |
+
>>> get_laplacian_kernel1d(5)
|
488 |
+
tensor([ 1., 1., -4., 1., 1.])
|
489 |
+
"""
|
490 |
+
if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
|
491 |
+
raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
|
492 |
+
window_1d: torch.Tensor = laplacian_1d(kernel_size)
|
493 |
+
return window_1d
|
494 |
+
|
495 |
+
|
496 |
+
def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor:
|
497 |
+
r"""Function that returns Gaussian filter matrix coefficients.
|
498 |
+
|
499 |
+
Args:
|
500 |
+
kernel_size: filter size should be odd.
|
501 |
+
|
502 |
+
Returns:
|
503 |
+
2D tensor with laplacian filter matrix coefficients.
|
504 |
+
|
505 |
+
Shape:
|
506 |
+
- Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
|
507 |
+
|
508 |
+
Examples:
|
509 |
+
>>> get_laplacian_kernel2d(3)
|
510 |
+
tensor([[ 1., 1., 1.],
|
511 |
+
[ 1., -8., 1.],
|
512 |
+
[ 1., 1., 1.]])
|
513 |
+
>>> get_laplacian_kernel2d(5)
|
514 |
+
tensor([[ 1., 1., 1., 1., 1.],
|
515 |
+
[ 1., 1., 1., 1., 1.],
|
516 |
+
[ 1., 1., -24., 1., 1.],
|
517 |
+
[ 1., 1., 1., 1., 1.],
|
518 |
+
[ 1., 1., 1., 1., 1.]])
|
519 |
+
"""
|
520 |
+
if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
|
521 |
+
raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
|
522 |
+
|
523 |
+
kernel = torch.ones((kernel_size, kernel_size))
|
524 |
+
mid = kernel_size // 2
|
525 |
+
kernel[mid, mid] = 1 - kernel_size**2
|
526 |
+
kernel_2d: torch.Tensor = kernel
|
527 |
+
return kernel_2d
|
528 |
+
|
529 |
+
|
530 |
+
def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.Tensor:
|
531 |
+
"""Generate pascal filter kernel by kernel size.
|
532 |
+
|
533 |
+
Args:
|
534 |
+
kernel_size: height and width of the kernel.
|
535 |
+
norm: if to normalize the kernel or not. Default: True.
|
536 |
+
|
537 |
+
Returns:
|
538 |
+
kernel shaped as :math:`(kernel_size, kernel_size)`
|
539 |
+
|
540 |
+
Examples:
|
541 |
+
>>> get_pascal_kernel_2d(1)
|
542 |
+
tensor([[1.]])
|
543 |
+
>>> get_pascal_kernel_2d(4)
|
544 |
+
tensor([[0.0156, 0.0469, 0.0469, 0.0156],
|
545 |
+
[0.0469, 0.1406, 0.1406, 0.0469],
|
546 |
+
[0.0469, 0.1406, 0.1406, 0.0469],
|
547 |
+
[0.0156, 0.0469, 0.0469, 0.0156]])
|
548 |
+
>>> get_pascal_kernel_2d(4, norm=False)
|
549 |
+
tensor([[1., 3., 3., 1.],
|
550 |
+
[3., 9., 9., 3.],
|
551 |
+
[3., 9., 9., 3.],
|
552 |
+
[1., 3., 3., 1.]])
|
553 |
+
"""
|
554 |
+
a = get_pascal_kernel_1d(kernel_size)
|
555 |
+
|
556 |
+
filt = a[:, None] * a[None, :]
|
557 |
+
if norm:
|
558 |
+
filt = filt / torch.sum(filt)
|
559 |
+
return filt
|
560 |
+
|
561 |
+
|
562 |
+
def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor:
|
563 |
+
"""Generate Yang Hui triangle (Pascal's triangle) by a given number.
|
564 |
+
|
565 |
+
Args:
|
566 |
+
kernel_size: height and width of the kernel.
|
567 |
+
norm: if to normalize the kernel or not. Default: False.
|
568 |
+
|
569 |
+
Returns:
|
570 |
+
kernel shaped as :math:`(kernel_size,)`
|
571 |
+
|
572 |
+
Examples:
|
573 |
+
>>> get_pascal_kernel_1d(1)
|
574 |
+
tensor([1.])
|
575 |
+
>>> get_pascal_kernel_1d(2)
|
576 |
+
tensor([1., 1.])
|
577 |
+
>>> get_pascal_kernel_1d(3)
|
578 |
+
tensor([1., 2., 1.])
|
579 |
+
>>> get_pascal_kernel_1d(4)
|
580 |
+
tensor([1., 3., 3., 1.])
|
581 |
+
>>> get_pascal_kernel_1d(5)
|
582 |
+
tensor([1., 4., 6., 4., 1.])
|
583 |
+
>>> get_pascal_kernel_1d(6)
|
584 |
+
tensor([ 1., 5., 10., 10., 5., 1.])
|
585 |
+
"""
|
586 |
+
pre: List[float] = []
|
587 |
+
cur: List[float] = []
|
588 |
+
for i in range(kernel_size):
|
589 |
+
cur = [1.0] * (i + 1)
|
590 |
+
|
591 |
+
for j in range(1, i // 2 + 1):
|
592 |
+
value = pre[j - 1] + pre[j]
|
593 |
+
cur[j] = value
|
594 |
+
if i != 2 * j:
|
595 |
+
cur[-j - 1] = value
|
596 |
+
pre = cur
|
597 |
+
|
598 |
+
out = torch.as_tensor(cur)
|
599 |
+
if norm:
|
600 |
+
out = out / torch.sum(out)
|
601 |
+
return out
|
602 |
+
|
603 |
+
|
604 |
+
def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
605 |
+
"""Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
|
606 |
+
kernel: torch.Tensor = torch.tensor(
|
607 |
+
[
|
608 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]],
|
609 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
|
610 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]],
|
611 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]],
|
612 |
+
[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
613 |
+
[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
614 |
+
[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
615 |
+
[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
616 |
+
],
|
617 |
+
device=device,
|
618 |
+
dtype=dtype,
|
619 |
+
)
|
620 |
+
return kernel.unsqueeze(1)
|
621 |
+
|
622 |
+
|
623 |
+
def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
624 |
+
"""Utility function that returns the 3x3 kernels for the Canny hysteresis."""
|
625 |
+
kernel: torch.Tensor = torch.tensor(
|
626 |
+
[
|
627 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
|
628 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
|
629 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
|
630 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
|
631 |
+
[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
632 |
+
[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
633 |
+
[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
634 |
+
[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
635 |
+
],
|
636 |
+
device=device,
|
637 |
+
dtype=dtype,
|
638 |
+
)
|
639 |
+
return kernel.unsqueeze(1)
|
640 |
+
|
641 |
+
|
642 |
+
def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
643 |
+
r"""Returns Hanning (also known as Hann) kernel, used in signal processing and KCF tracker.
|
644 |
+
|
645 |
+
.. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right)
|
646 |
+
\\qquad 0 \\leq n \\leq M-1
|
647 |
+
|
648 |
+
See further in numpy docs https://numpy.org/doc/stable/reference/generated/numpy.hanning.html
|
649 |
+
|
650 |
+
Args:
|
651 |
+
kernel_size: The size the of the kernel. It should be positive.
|
652 |
+
|
653 |
+
Returns:
|
654 |
+
1D tensor with Hanning filter coefficients.
|
655 |
+
.. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right)
|
656 |
+
|
657 |
+
Shape:
|
658 |
+
- Output: math:`(\text{kernel_size})`
|
659 |
+
|
660 |
+
Examples:
|
661 |
+
>>> get_hanning_kernel1d(4)
|
662 |
+
tensor([0.0000, 0.7500, 0.7500, 0.0000])
|
663 |
+
"""
|
664 |
+
if not isinstance(kernel_size, int) or kernel_size <= 2:
|
665 |
+
raise TypeError(f"ksize must be an positive integer > 2. Got {kernel_size}")
|
666 |
+
|
667 |
+
x: torch.Tensor = torch.arange(kernel_size, device=device, dtype=dtype)
|
668 |
+
x = 0.5 - 0.5 * torch.cos(2.0 * math.pi * x / float(kernel_size - 1))
|
669 |
+
return x
|
670 |
+
|
671 |
+
|
672 |
+
def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
673 |
+
r"""Returns 2d Hanning kernel, used in signal processing and KCF tracker.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
kernel_size: The size of the kernel for the filter. It should be positive.
|
677 |
+
|
678 |
+
Returns:
|
679 |
+
2D tensor with Hanning filter coefficients.
|
680 |
+
.. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right)
|
681 |
+
|
682 |
+
Shape:
|
683 |
+
- Output: math:`(\text{kernel_size[0], kernel_size[1]})`
|
684 |
+
"""
|
685 |
+
if kernel_size[0] <= 2 or kernel_size[1] <= 2:
|
686 |
+
raise TypeError(f"ksize must be an tuple of positive integers > 2. Got {kernel_size}")
|
687 |
+
ky: torch.Tensor = get_hanning_kernel1d(kernel_size[0], device, dtype)[None].T
|
688 |
+
kx: torch.Tensor = get_hanning_kernel1d(kernel_size[1], device, dtype)[None]
|
689 |
+
kernel2d = ky @ kx
|
690 |
+
return kernel2d
|
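Usage note (not part of the diff): a quick sketch of a few kernel builders from the file above, assuming the same package layout.

import torch
from propainter.model.canny.kernels import get_canny_nms_kernel, get_gaussian_kernel2d, get_laplacian_kernel2d

g = get_gaussian_kernel2d((3, 3), (1.5, 1.5))    # (3, 3) gaussian weights, sums to 1
lap = get_laplacian_kernel2d(5)                  # (5, 5) laplacian kernel
nms = get_canny_nms_kernel()                     # (8, 1, 3, 3) directional NMS kernels used by canny()
print(float(g.sum()), lap.shape, nms.shape)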
propainter/model/canny/sobel.py
ADDED
@@ -0,0 +1,263 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d
|
6 |
+
|
7 |
+
|
8 |
+
def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor:
|
9 |
+
r"""Compute the first order image derivative in both x and y using a Sobel operator.
|
10 |
+
|
11 |
+
.. image:: _static/img/spatial_gradient.png
|
12 |
+
|
13 |
+
Args:
|
14 |
+
input: input image tensor with shape :math:`(B, C, H, W)`.
|
15 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
16 |
+
order: the order of the derivatives.
|
17 |
+
normalized: whether the output is normalized.
|
18 |
+
|
19 |
+
Return:
|
20 |
+
the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
|
21 |
+
|
22 |
+
.. note::
|
23 |
+
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
|
24 |
+
filtering_edges.html>`__.
|
25 |
+
|
26 |
+
Examples:
|
27 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
28 |
+
>>> output = spatial_gradient(input) # 1x3x2x4x4
|
29 |
+
>>> output.shape
|
30 |
+
torch.Size([1, 3, 2, 4, 4])
|
31 |
+
"""
|
32 |
+
if not isinstance(input, torch.Tensor):
|
33 |
+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
|
34 |
+
|
35 |
+
if not len(input.shape) == 4:
|
36 |
+
raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
|
37 |
+
# allocate kernel
|
38 |
+
kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order)
|
39 |
+
if normalized:
|
40 |
+
kernel = normalize_kernel2d(kernel)
|
41 |
+
|
42 |
+
# prepare kernel
|
43 |
+
b, c, h, w = input.shape
|
44 |
+
tmp_kernel: torch.Tensor = kernel.to(input).detach()
|
45 |
+
tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
|
46 |
+
|
47 |
+
# convolve input tensor with sobel kernel
|
48 |
+
kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
|
49 |
+
|
50 |
+
# Pad with "replicate for spatial dims, but with zeros for channel
|
51 |
+
spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
|
52 |
+
out_channels: int = 3 if order == 2 else 2
|
53 |
+
padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None]
|
54 |
+
|
55 |
+
return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w)
|
56 |
+
|
57 |
+
|
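A minimal sketch of calling `spatial_gradient` directly (the import path and the test image are illustrative): the result stacks the x- and y-derivatives along a new third dimension.

import torch
from propainter.model.canny.sobel import spatial_gradient  # assumed import path

# A horizontal ramp: intensity grows along x only, so the y-derivative
# should vanish while the x-derivative is roughly constant and positive
# (exact sign/scale depend on the Sobel kernel defined in kernels.py).
img = torch.arange(8, dtype=torch.float32).repeat(1, 1, 8, 1)  # (1, 1, 8, 8)
grads = spatial_gradient(img, mode='sobel', order=1)           # (1, 1, 2, 8, 8)
dx, dy = grads[:, :, 0], grads[:, :, 1]
print(dx[0, 0, 4, 2:6])   # interior x-derivative values
print(dy.abs().max())     # ~0 everywhere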
58 |
+
def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor:
|
59 |
+
r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
input: input features tensor with shape :math:`(B, C, D, H, W)`.
|
63 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
64 |
+
order: the order of the derivatives.
|
65 |
+
|
66 |
+
Return:
|
67 |
+
the spatial gradients of the input feature map with shape :math:`(B, C, 3, D, H, W)`
|
68 |
+
or :math:`(B, C, 6, D, H, W)`.
|
69 |
+
|
70 |
+
Examples:
|
71 |
+
>>> input = torch.rand(1, 4, 2, 4, 4)
|
72 |
+
>>> output = spatial_gradient3d(input)
|
73 |
+
>>> output.shape
|
74 |
+
torch.Size([1, 4, 3, 2, 4, 4])
|
75 |
+
"""
|
76 |
+
if not isinstance(input, torch.Tensor):
|
77 |
+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
|
78 |
+
|
79 |
+
if not len(input.shape) == 5:
|
80 |
+
raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
|
81 |
+
b, c, d, h, w = input.shape
|
82 |
+
dev = input.device
|
83 |
+
dtype = input.dtype
|
84 |
+
if (mode == 'diff') and (order == 1):
|
85 |
+
# we go for the special case implementation due to conv3d bad speed
|
86 |
+
x: torch.Tensor = F.pad(input, 6 * [1], 'replicate')
|
87 |
+
center = slice(1, -1)
|
88 |
+
left = slice(0, -2)
|
89 |
+
right = slice(2, None)
|
90 |
+
out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype)
|
91 |
+
out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left]
|
92 |
+
out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center]
|
93 |
+
out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center]
|
94 |
+
out = 0.5 * out
|
95 |
+
else:
|
96 |
+
# prepare kernel
|
97 |
+
# allocate kernel
|
98 |
+
kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order)
|
99 |
+
|
100 |
+
tmp_kernel: torch.Tensor = kernel.to(input).detach()
|
101 |
+
tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1)
|
102 |
+
|
103 |
+
# convolve input tensor with grad kernel
|
104 |
+
kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
|
105 |
+
|
106 |
+
# Pad with "replicate for spatial dims, but with zeros for channel
|
107 |
+
spatial_pad = [
|
108 |
+
kernel.size(2) // 2,
|
109 |
+
kernel.size(2) // 2,
|
110 |
+
kernel.size(3) // 2,
|
111 |
+
kernel.size(3) // 2,
|
112 |
+
kernel.size(4) // 2,
|
113 |
+
kernel.size(4) // 2,
|
114 |
+
]
|
115 |
+
out_ch: int = 6 if order == 2 else 3
|
116 |
+
out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view(
|
117 |
+
b, c, out_ch, d, h, w
|
118 |
+
)
|
119 |
+
return out
|
120 |
+
|
121 |
+
|
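A short sketch for the volumetric variant (import path and shapes are illustrative); with the default `diff`/order-1 settings it takes the fast central-difference path above.

import torch
from propainter.model.canny.sobel import spatial_gradient3d  # assumed import path

vol = torch.rand(1, 2, 5, 6, 7)                   # (B, C, D, H, W)
g = spatial_gradient3d(vol, mode='diff', order=1)
print(g.shape)                                    # torch.Size([1, 2, 3, 5, 6, 7])
# channels 0/1/2 hold 0.5 * (x[i+1] - x[i-1]) along W, H and D respectively,
# with replicate padding at the borders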
122 |
+
def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor:
|
123 |
+
r"""Compute the Sobel operator and returns the magnitude per channel.
|
124 |
+
|
125 |
+
.. image:: _static/img/sobel.png
|
126 |
+
|
127 |
+
Args:
|
128 |
+
input: the input image with shape :math:`(B,C,H,W)`.
|
129 |
+
normalized: if True, L1 norm of the kernel is set to 1.
|
130 |
+
eps: regularization number to avoid NaN during backprop.
|
131 |
+
|
132 |
+
Return:
|
133 |
+
the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`.
|
134 |
+
|
135 |
+
.. note::
|
136 |
+
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
|
137 |
+
filtering_edges.html>`__.
|
138 |
+
|
139 |
+
Example:
|
140 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
141 |
+
>>> output = sobel(input) # 1x3x4x4
|
142 |
+
>>> output.shape
|
143 |
+
torch.Size([1, 3, 4, 4])
|
144 |
+
"""
|
145 |
+
if not isinstance(input, torch.Tensor):
|
146 |
+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
|
147 |
+
|
148 |
+
if not len(input.shape) == 4:
|
149 |
+
raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
|
150 |
+
|
151 |
+
# compute the x/y gradients
|
152 |
+
edges: torch.Tensor = spatial_gradient(input, normalized=normalized)
|
153 |
+
|
154 |
+
# unpack the edges
|
155 |
+
gx: torch.Tensor = edges[:, :, 0]
|
156 |
+
gy: torch.Tensor = edges[:, :, 1]
|
157 |
+
|
158 |
+
# compute gradient magnitude
|
159 |
+
magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
|
160 |
+
|
161 |
+
return magnitude
|
162 |
+
|
163 |
+
|
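A minimal check of what `sobel` returns (the import path is an assumption): the magnitude is simply sqrt(gx^2 + gy^2 + eps) of the spatial gradients.

import torch
from propainter.model.canny.sobel import sobel, spatial_gradient  # assumed import path

img = torch.rand(2, 3, 16, 16)
edges = sobel(img)                                 # (2, 3, 16, 16)
g = spatial_gradient(img)
manual = torch.sqrt(g[:, :, 0] ** 2 + g[:, :, 1] ** 2 + 1e-6)
assert torch.allclose(edges, manual)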
164 |
+
class SpatialGradient(nn.Module):
|
165 |
+
r"""Compute the first order image derivative in both x and y using a Sobel operator.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
169 |
+
order: the order of the derivatives.
|
170 |
+
normalized: whether the output is normalized.
|
171 |
+
|
172 |
+
Return:
|
173 |
+
the sobel edges of the input feature map.
|
174 |
+
|
175 |
+
Shape:
|
176 |
+
- Input: :math:`(B, C, H, W)`
|
177 |
+
- Output: :math:`(B, C, 2, H, W)`
|
178 |
+
|
179 |
+
Examples:
|
180 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
181 |
+
>>> output = SpatialGradient()(input) # 1x3x2x4x4
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None:
|
185 |
+
super().__init__()
|
186 |
+
self.normalized: bool = normalized
|
187 |
+
self.order: int = order
|
188 |
+
self.mode: str = mode
|
189 |
+
|
190 |
+
def __repr__(self) -> str:
|
191 |
+
return (
|
192 |
+
self.__class__.__name__ + '('
|
193 |
+
'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')'
|
194 |
+
)
|
195 |
+
|
196 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
197 |
+
return spatial_gradient(input, self.mode, self.order, self.normalized)
|
198 |
+
|
199 |
+
|
200 |
+
class SpatialGradient3d(nn.Module):
|
201 |
+
r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
205 |
+
order: the order of the derivatives.
|
206 |
+
|
207 |
+
Return:
|
208 |
+
the spatial gradients of the input feature map.
|
209 |
+
|
210 |
+
Shape:
|
211 |
+
- Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions; the gradient is calculated w.r.t. them.
|
212 |
+
- Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)`
|
213 |
+
|
214 |
+
Examples:
|
215 |
+
>>> input = torch.rand(1, 4, 2, 4, 4)
|
216 |
+
>>> output = SpatialGradient3d()(input)
|
217 |
+
>>> output.shape
|
218 |
+
torch.Size([1, 4, 3, 2, 4, 4])
|
219 |
+
"""
|
220 |
+
|
221 |
+
def __init__(self, mode: str = 'diff', order: int = 1) -> None:
|
222 |
+
super().__init__()
|
223 |
+
self.order: int = order
|
224 |
+
self.mode: str = mode
|
225 |
+
self.kernel = get_spatial_gradient_kernel3d(mode, order)
|
226 |
+
return
|
227 |
+
|
228 |
+
def __repr__(self) -> str:
|
229 |
+
return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')'
|
230 |
+
|
231 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
|
232 |
+
return spatial_gradient3d(input, self.mode, self.order)
|
233 |
+
|
234 |
+
|
235 |
+
class Sobel(nn.Module):
|
236 |
+
r"""Compute the Sobel operator and returns the magnitude per channel.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
normalized: if True, L1 norm of the kernel is set to 1.
|
240 |
+
eps: regularization number to avoid NaN during backprop.
|
241 |
+
|
242 |
+
Return:
|
243 |
+
the sobel edge gradient magnitudes map.
|
244 |
+
|
245 |
+
Shape:
|
246 |
+
- Input: :math:`(B, C, H, W)`
|
247 |
+
- Output: :math:`(B, C, H, W)`
|
248 |
+
|
249 |
+
Examples:
|
250 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
251 |
+
>>> output = Sobel()(input) # 1x3x4x4
|
252 |
+
"""
|
253 |
+
|
254 |
+
def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None:
|
255 |
+
super().__init__()
|
256 |
+
self.normalized: bool = normalized
|
257 |
+
self.eps: float = eps
|
258 |
+
|
259 |
+
def __repr__(self) -> str:
|
260 |
+
return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')'
|
261 |
+
|
262 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
263 |
+
return sobel(input, self.normalized, self.eps)
|
propainter/model/misc.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import random
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
from os import path as osp
|
10 |
+
|
11 |
+
def constant_init(module, val, bias=0):
|
12 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
13 |
+
nn.init.constant_(module.weight, val)
|
14 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
15 |
+
nn.init.constant_(module.bias, bias)
|
16 |
+
|
17 |
+
initialized_logger = {}
|
18 |
+
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
|
19 |
+
"""Get the root logger.
|
20 |
+
The logger will be initialized if it has not been initialized. By default a
|
21 |
+
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
22 |
+
also be added.
|
23 |
+
Args:
|
24 |
+
logger_name (str): root logger name. Default: 'basicsr'.
|
25 |
+
log_file (str | None): The log filename. If specified, a FileHandler
|
26 |
+
will be added to the root logger.
|
27 |
+
log_level (int): The root logger level. Note that only the process of
|
28 |
+
rank 0 is affected, while other processes will set the level to
|
29 |
+
"Error" and be silent most of the time.
|
30 |
+
Returns:
|
31 |
+
logging.Logger: The root logger.
|
32 |
+
"""
|
33 |
+
logger = logging.getLogger(logger_name)
|
34 |
+
# if the logger has been initialized, just return it
|
35 |
+
if logger_name in initialized_logger:
|
36 |
+
return logger
|
37 |
+
|
38 |
+
format_str = '%(asctime)s %(levelname)s: %(message)s'
|
39 |
+
stream_handler = logging.StreamHandler()
|
40 |
+
stream_handler.setFormatter(logging.Formatter(format_str))
|
41 |
+
logger.addHandler(stream_handler)
|
42 |
+
logger.propagate = False
|
43 |
+
|
44 |
+
if log_file is not None:
|
45 |
+
logger.setLevel(log_level)
|
46 |
+
# add file handler
|
47 |
+
# file_handler = logging.FileHandler(log_file, 'w')
|
48 |
+
file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
|
49 |
+
file_handler.setFormatter(logging.Formatter(format_str))
|
50 |
+
file_handler.setLevel(log_level)
|
51 |
+
logger.addHandler(file_handler)
|
52 |
+
initialized_logger[logger_name] = True
|
53 |
+
return logger
|
54 |
+
|
55 |
+
|
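A minimal sketch of using the logger helper (the logger name and file name are illustrative): the first call configures handlers, and later calls with the same name return the cached instance.

import logging
from propainter.model.misc import get_root_logger  # assumed import path

logger = get_root_logger('propainter', logging.INFO, log_file='propainter.log')
logger.info('starting inference ...')
assert get_root_logger('propainter') is logger      # cached on repeat calls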
56 |
+
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
|
57 |
+
torch.__version__)[0][:3])] >= [1, 12, 0]
|
58 |
+
|
59 |
+
def gpu_is_available():
|
60 |
+
if IS_HIGH_VERSION:
|
61 |
+
if torch.backends.mps.is_available():
|
62 |
+
return True
|
63 |
+
return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
|
64 |
+
|
65 |
+
def get_device(gpu_id=None):
|
66 |
+
if gpu_id is None:
|
67 |
+
gpu_str = ''
|
68 |
+
elif isinstance(gpu_id, int):
|
69 |
+
gpu_str = f':{gpu_id}'
|
70 |
+
else:
|
71 |
+
raise TypeError('Input should be an int value.')
|
72 |
+
|
73 |
+
if IS_HIGH_VERSION:
|
74 |
+
if torch.backends.mps.is_available():
|
75 |
+
return torch.device('mps'+gpu_str)
|
76 |
+
return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
|
77 |
+
|
78 |
+
|
79 |
+
def set_random_seed(seed):
|
80 |
+
"""Set random seeds."""
|
81 |
+
random.seed(seed)
|
82 |
+
np.random.seed(seed)
|
83 |
+
torch.manual_seed(seed)
|
84 |
+
torch.cuda.manual_seed(seed)
|
85 |
+
torch.cuda.manual_seed_all(seed)
|
86 |
+
|
87 |
+
|
88 |
+
def get_time_str():
|
89 |
+
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
90 |
+
|
91 |
+
|
92 |
+
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
93 |
+
"""Scan a directory to find the interested files.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
dir_path (str): Path of the directory.
|
97 |
+
suffix (str | tuple(str), optional): File suffix that we are
|
98 |
+
interested in. Default: None.
|
99 |
+
recursive (bool, optional): If set to True, recursively scan the
|
100 |
+
directory. Default: False.
|
101 |
+
full_path (bool, optional): If set to True, include the dir_path.
|
102 |
+
Default: False.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
A generator for all the interested files with relative paths.
|
106 |
+
"""
|
107 |
+
|
108 |
+
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
109 |
+
raise TypeError('"suffix" must be a string or tuple of strings')
|
110 |
+
|
111 |
+
root = dir_path
|
112 |
+
|
113 |
+
def _scandir(dir_path, suffix, recursive):
|
114 |
+
for entry in os.scandir(dir_path):
|
115 |
+
if not entry.name.startswith('.') and entry.is_file():
|
116 |
+
if full_path:
|
117 |
+
return_path = entry.path
|
118 |
+
else:
|
119 |
+
return_path = osp.relpath(entry.path, root)
|
120 |
+
|
121 |
+
if suffix is None:
|
122 |
+
yield return_path
|
123 |
+
elif return_path.endswith(suffix):
|
124 |
+
yield return_path
|
125 |
+
else:
|
126 |
+
if recursive:
|
127 |
+
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
|
128 |
+
else:
|
129 |
+
continue
|
130 |
+
|
131 |
+
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
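A short sketch tying the helpers in this file together (the import path and the frames directory are illustrative assumptions).

import torch
from propainter.model.misc import scandir, get_device, set_random_seed  # assumed import path

set_random_seed(2023)                      # seeds python, numpy and torch
device = get_device()                      # cuda / mps / cpu, picked automatically
frames = sorted(scandir('inputs/frames', suffix=('.jpg', '.png'), full_path=True))
print(device, len(frames))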
propainter/model/modules/base_module.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from functools import reduce
|
6 |
+
|
7 |
+
class BaseNetwork(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(BaseNetwork, self).__init__()
|
10 |
+
|
11 |
+
def print_network(self):
|
12 |
+
if isinstance(self, list):
|
13 |
+
self = self[0]
|
14 |
+
num_params = 0
|
15 |
+
for param in self.parameters():
|
16 |
+
num_params += param.numel()
|
17 |
+
print(
|
18 |
+
'Network [%s] was created. Total number of parameters: %.1f million. '
|
19 |
+
'To see the architecture, do print(network).' %
|
20 |
+
(type(self).__name__, num_params / 1000000))
|
21 |
+
|
22 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
23 |
+
'''
|
24 |
+
initialize network's weights
|
25 |
+
init_type: normal | xavier | kaiming | orthogonal
|
26 |
+
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
|
27 |
+
'''
|
28 |
+
def init_func(m):
|
29 |
+
classname = m.__class__.__name__
|
30 |
+
if classname.find('InstanceNorm2d') != -1:
|
31 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
32 |
+
nn.init.constant_(m.weight.data, 1.0)
|
33 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
34 |
+
nn.init.constant_(m.bias.data, 0.0)
|
35 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1
|
36 |
+
or classname.find('Linear') != -1):
|
37 |
+
if init_type == 'normal':
|
38 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
39 |
+
elif init_type == 'xavier':
|
40 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
41 |
+
elif init_type == 'xavier_uniform':
|
42 |
+
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
|
43 |
+
elif init_type == 'kaiming':
|
44 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
45 |
+
elif init_type == 'orthogonal':
|
46 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
47 |
+
elif init_type == 'none': # uses pytorch's default init method
|
48 |
+
m.reset_parameters()
|
49 |
+
else:
|
50 |
+
raise NotImplementedError(
|
51 |
+
'initialization method [%s] is not implemented' %
|
52 |
+
init_type)
|
53 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
54 |
+
nn.init.constant_(m.bias.data, 0.0)
|
55 |
+
|
56 |
+
self.apply(init_func)
|
57 |
+
|
58 |
+
# propagate to children
|
59 |
+
for m in self.children():
|
60 |
+
if hasattr(m, 'init_weights'):
|
61 |
+
m.init_weights(init_type, gain)
|
62 |
+
|
63 |
+
|
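A minimal sketch of how a network might use `BaseNetwork.init_weights` (the subclass below is illustrative, not from the repository).

import torch.nn as nn
from propainter.model.modules.base_module import BaseNetwork  # assumed import path

class TinyNet(BaseNetwork):                 # illustrative subclass
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

net = TinyNet()
net.init_weights(init_type='kaiming')       # re-initialises Conv/Linear weights
net.print_network()                         # reports the parameter count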
64 |
+
class Vec2Feat(nn.Module):
|
65 |
+
def __init__(self, channel, hidden, kernel_size, stride, padding):
|
66 |
+
super(Vec2Feat, self).__init__()
|
67 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
68 |
+
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
|
69 |
+
self.embedding = nn.Linear(hidden, c_out)
|
70 |
+
self.kernel_size = kernel_size
|
71 |
+
self.stride = stride
|
72 |
+
self.padding = padding
|
73 |
+
self.bias_conv = nn.Conv2d(channel,
|
74 |
+
channel,
|
75 |
+
kernel_size=3,
|
76 |
+
stride=1,
|
77 |
+
padding=1)
|
78 |
+
|
79 |
+
def forward(self, x, t, output_size):
|
80 |
+
b_, _, _, _, c_ = x.shape
|
81 |
+
x = x.view(b_, -1, c_)
|
82 |
+
feat = self.embedding(x)
|
83 |
+
b, _, c = feat.size()
|
84 |
+
feat = feat.view(b * t, -1, c).permute(0, 2, 1)
|
85 |
+
feat = F.fold(feat,
|
86 |
+
output_size=output_size,
|
87 |
+
kernel_size=self.kernel_size,
|
88 |
+
stride=self.stride,
|
89 |
+
padding=self.padding)
|
90 |
+
feat = self.bias_conv(feat)
|
91 |
+
return feat
|
92 |
+
|
93 |
+
|
94 |
+
class FusionFeedForward(nn.Module):
|
95 |
+
def __init__(self, dim, hidden_dim=1960, t2t_params=None):
|
96 |
+
super(FusionFeedForward, self).__init__()
|
97 |
+
# We set hidden_dim as a default to 1960
|
98 |
+
self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
|
99 |
+
self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
|
100 |
+
assert t2t_params is not None
|
101 |
+
self.t2t_params = t2t_params
|
102 |
+
self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49
|
103 |
+
|
104 |
+
def forward(self, x, output_size):
|
105 |
+
n_vecs = 1
|
106 |
+
for i, d in enumerate(self.t2t_params['kernel_size']):
|
107 |
+
n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
|
108 |
+
(d - 1) - 1) / self.t2t_params['stride'][i] + 1)
|
109 |
+
|
110 |
+
x = self.fc1(x)
|
111 |
+
b, n, c = x.size()
|
112 |
+
normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
|
113 |
+
normalizer = F.fold(normalizer,
|
114 |
+
output_size=output_size,
|
115 |
+
kernel_size=self.t2t_params['kernel_size'],
|
116 |
+
padding=self.t2t_params['padding'],
|
117 |
+
stride=self.t2t_params['stride'])
|
118 |
+
|
119 |
+
x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
|
120 |
+
output_size=output_size,
|
121 |
+
kernel_size=self.t2t_params['kernel_size'],
|
122 |
+
padding=self.t2t_params['padding'],
|
123 |
+
stride=self.t2t_params['stride'])
|
124 |
+
|
125 |
+
x = F.unfold(x / normalizer,
|
126 |
+
kernel_size=self.t2t_params['kernel_size'],
|
127 |
+
padding=self.t2t_params['padding'],
|
128 |
+
stride=self.t2t_params['stride']).permute(
|
129 |
+
0, 2, 1).contiguous().view(b, n, c)
|
130 |
+
x = self.fc2(x)
|
131 |
+
return x
|
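A usage sketch for `FusionFeedForward` with illustrative soft-split settings (the real values must match the transformer's token layout): the fold/unfold round trip re-normalises overlapping patches before the second linear layer.

import torch
from propainter.model.modules.base_module import FusionFeedForward  # assumed import path

# with a 36x36 feature map, kernel 7, stride 3, padding 3 -> 12 x 12 = 144 tokens
t2t_params = {'kernel_size': (7, 7), 'stride': (3, 3), 'padding': (3, 3)}
ff = FusionFeedForward(dim=512, hidden_dim=1960, t2t_params=t2t_params)

tokens = torch.rand(1, 144, 512)            # (batch, n_vecs, dim)
out = ff(tokens, output_size=(36, 36))
print(out.shape)                            # torch.Size([1, 144, 512])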
propainter/model/modules/deformconv.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init as init
|
4 |
+
from torch.nn.modules.utils import _pair, _single
|
5 |
+
import math
|
6 |
+
|
7 |
+
class ModulatedDeformConv2d(nn.Module):
|
8 |
+
def __init__(self,
|
9 |
+
in_channels,
|
10 |
+
out_channels,
|
11 |
+
kernel_size,
|
12 |
+
stride=1,
|
13 |
+
padding=0,
|
14 |
+
dilation=1,
|
15 |
+
groups=1,
|
16 |
+
deform_groups=1,
|
17 |
+
bias=True):
|
18 |
+
super(ModulatedDeformConv2d, self).__init__()
|
19 |
+
|
20 |
+
self.in_channels = in_channels
|
21 |
+
self.out_channels = out_channels
|
22 |
+
self.kernel_size = _pair(kernel_size)
|
23 |
+
self.stride = stride
|
24 |
+
self.padding = padding
|
25 |
+
self.dilation = dilation
|
26 |
+
self.groups = groups
|
27 |
+
self.deform_groups = deform_groups
|
28 |
+
self.with_bias = bias
|
29 |
+
# enable compatibility with nn.Conv2d
|
30 |
+
self.transposed = False
|
31 |
+
self.output_padding = _single(0)
|
32 |
+
|
33 |
+
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
|
34 |
+
if bias:
|
35 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
36 |
+
else:
|
37 |
+
self.register_parameter('bias', None)
|
38 |
+
self.init_weights()
|
39 |
+
|
40 |
+
def init_weights(self):
|
41 |
+
n = self.in_channels
|
42 |
+
for k in self.kernel_size:
|
43 |
+
n *= k
|
44 |
+
stdv = 1. / math.sqrt(n)
|
45 |
+
self.weight.data.uniform_(-stdv, stdv)
|
46 |
+
if self.bias is not None:
|
47 |
+
self.bias.data.zero_()
|
48 |
+
|
49 |
+
if hasattr(self, 'conv_offset'):
|
50 |
+
self.conv_offset.weight.data.zero_()
|
51 |
+
self.conv_offset.bias.data.zero_()
|
52 |
+
|
53 |
+
def forward(self, x, offset, mask):
|
54 |
+
pass
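The `forward` above is intentionally a stub; a concrete subclass is expected to supply the actual deformable convolution. A minimal sketch of such an override, assuming torchvision's `deform_conv2d` op is available (the subclass is illustrative, not part of the repository):

import torchvision
from propainter.model.modules.deformconv import ModulatedDeformConv2d  # assumed import path

class SimpleModulatedDeformConv2d(ModulatedDeformConv2d):  # illustrative subclass
    def forward(self, x, offset, mask):
        # offset: (N, 2 * deform_groups * kh * kw, H_out, W_out)
        # mask:   (N, deform_groups * kh * kw, H_out, W_out)
        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias,
                                             self.stride, self.padding,
                                             self.dilation, mask)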
|