Upload 25 files
- LICENSE +313 -0
- cog.yaml +37 -0
- configs/test.yaml +20 -0
- constants.py +2 -0
- environment.yaml +16 -0
- inference.py +138 -0
- mimicmotion/__init__.py +0 -0
- mimicmotion/dwpose/.gitignore +1 -0
- mimicmotion/dwpose/__init__.py +0 -0
- mimicmotion/dwpose/dwpose_detector.py +72 -0
- mimicmotion/dwpose/onnxdet.py +145 -0
- mimicmotion/dwpose/onnxpose.py +375 -0
- mimicmotion/dwpose/preprocess.py +73 -0
- mimicmotion/dwpose/util.py +133 -0
- mimicmotion/dwpose/wholebody.py +57 -0
- mimicmotion/modules/__init__.py +0 -0
- mimicmotion/modules/attention.py +378 -0
- mimicmotion/modules/pose_net.py +80 -0
- mimicmotion/modules/unet.py +507 -0
- mimicmotion/pipelines/pipeline_mimicmotion.py +628 -0
- mimicmotion/utils/__init__.py +0 -0
- mimicmotion/utils/geglu_patch.py +9 -0
- mimicmotion/utils/loader.py +53 -0
- mimicmotion/utils/utils.py +12 -0
- predict.py +363 -0
LICENSE
ADDED
@@ -0,0 +1,313 @@
Tencent is pleased to support the open source community by making MimicMotion available.

Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.

MimicMotion is licensed under the Apache License Version 2.0 except for the third-party components listed below.

Terms of the Apache License Version 2.0:
--------------------------------------------------------------------
Apache License

Version 2.0, January 2004

http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

You must give any other recipients of the Work or Derivative Works a copy of this License; and

You must cause any modified files to carry prominent notices stating that You changed the files; and

You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

Other dependencies and licenses:

Open Source Software Licensed under the Apache License Version 2.0:
The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2023 THL A29 Limited.
--------------------------------------------------------------------
1. diffusers
Copyright (c) diffusers original author and authors

2. DWPose
Copyright 2018-2020 Open-MMLab.
Please note this software has been modified by Tencent in this distribution.

3. transformers
Copyright (c) transformers original author and authors

4. decord
Copyright (c) decord original author and authors

A copy of Apache 2.0 has been included in this file.

Open Source Software Licensed under the BSD 3-Clause License:
--------------------------------------------------------------------
1. torch
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

2. omegaconf
Copyright (c) 2018, Omry Yadan
All rights reserved.

3. torchvision
Copyright (c) Soumith Chintala 2016,
All rights reserved.

Terms of the BSD 3-Clause:
--------------------------------------------------------------------
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
--------------------------------------------------------------------
1. numpy
Copyright (c) 2005-2023, NumPy Developers.
All rights reserved.

A copy of the BSD 3-Clause is included in this file.

For the license of other third party components, please refer to the following URL:
https://github.com/numpy/numpy/blob/v1.26.3/LICENSES_bundled.txt

Open Source Software Licensed under the HPND License:
--------------------------------------------------------------------
1. Pillow
Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors.

Terms of the HPND License:
--------------------------------------------------------------------
The Python Imaging Library (PIL) is

Copyright © 1997-2011 by Secret Labs AB
Copyright © 1995-2011 by Fredrik Lundh

Pillow is the friendly PIL fork. It is

Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors.

Like PIL, Pillow is licensed under the open source HPND License:

By obtaining, using, and/or copying this software and/or its associated
documentation, you agree that you have read, understood, and will comply
with the following terms and conditions:

Permission to use, copy, modify and distribute this software and its
documentation for any purpose and without fee is hereby granted,
provided that the above copyright notice appears in all copies, and that
both that copyright notice and this permission notice appear in supporting
documentation, and that the name of Secret Labs AB or the author not be
used in advertising or publicity pertaining to distribution of the software
without specific, written prior permission.

SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.
IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,
INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

Open Source Software Licensed under the Matplotlib License and Other Licenses of the Third-Party Components therein:
--------------------------------------------------------------------
1. matplotlib
Copyright (c)
2012- Matplotlib Development Team; All Rights Reserved

Terms of the Matplotlib License:
--------------------------------------------------------------------
License agreement for matplotlib versions 1.3.0 and later
=========================================================

1. This LICENSE AGREEMENT is between the Matplotlib Development Team
("MDT"), and the Individual or Organization ("Licensee") accessing and
otherwise using matplotlib software in source or binary form and its
associated documentation.

2. Subject to the terms and conditions of this License Agreement, MDT
hereby grants Licensee a nonexclusive, royalty-free, world-wide license
to reproduce, analyze, test, perform and/or display publicly, prepare
derivative works, distribute, and otherwise use matplotlib
alone or in any derivative version, provided, however, that MDT's
License Agreement and MDT's notice of copyright, i.e., "Copyright (c)
2012- Matplotlib Development Team; All Rights Reserved" are retained in
matplotlib alone or in any derivative version prepared by
Licensee.

3. In the event Licensee prepares a derivative work that is based on or
incorporates matplotlib or any part thereof, and wants to
make the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to matplotlib.

4. MDT is making matplotlib available to Licensee on an "AS
IS" basis. MDT MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, MDT MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB
WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.

5. MDT SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR
LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING
MATPLOTLIB, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF
THE POSSIBILITY THEREOF.

6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.

7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between MDT and
Licensee. This License Agreement does not grant permission to use MDT
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.

8. By copying, installing or otherwise using matplotlib,
Licensee agrees to be bound by the terms and conditions of this License
Agreement.

License agreement for matplotlib versions prior to 1.3.0
========================================================

1. This LICENSE AGREEMENT is between John D. Hunter ("JDH"), and the
Individual or Organization ("Licensee") accessing and otherwise using
matplotlib software in source or binary form and its associated
documentation.

2. Subject to the terms and conditions of this License Agreement, JDH
hereby grants Licensee a nonexclusive, royalty-free, world-wide license
to reproduce, analyze, test, perform and/or display publicly, prepare
derivative works, distribute, and otherwise use matplotlib
alone or in any derivative version, provided, however, that JDH's
License Agreement and JDH's notice of copyright, i.e., "Copyright (c)
2002-2011 John D. Hunter; All Rights Reserved" are retained in
matplotlib alone or in any derivative version prepared by
Licensee.

3. In the event Licensee prepares a derivative work that is based on or
incorporates matplotlib or any part thereof, and wants to
make the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to matplotlib.

4. JDH is making matplotlib available to Licensee on an "AS
IS" basis. JDH MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, JDH MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB
WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.

5. JDH SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR
LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING
MATPLOTLIB, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF
THE POSSIBILITY THEREOF.

6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.

7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between JDH and
Licensee. This License Agreement does not grant permission to use JDH
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.

8. By copying, installing or otherwise using matplotlib,
Licensee agrees to be bound by the terms and conditions of this License
Agreement.

For the license of other third party components, please refer to the following URL:
https://github.com/matplotlib/matplotlib/tree/v3.8.0/LICENSE

Open Source Software Licensed under the MIT License:
--------------------------------------------------------------------
1. einops
Copyright (c) 2018 Alex Rogozhnikov

2. onnxruntime
Copyright (c) Microsoft Corporation

3. OpenCV
Copyright (c) Olli-Pekka Heinisuo

Terms of the MIT License:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
cog.yaml
ADDED
@@ -0,0 +1,37 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  gpu: true
  # cuda: "11.7"

  # a list of ubuntu apt packages to install
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "torch>=2.3"            # 2.3.1
    - "torchvision>=0.18"     # 0.18.1
    - "diffusers>=0.29"       # 0.29.2
    - "transformers>=4.42"    # 4.42.3
    - "decord>=0.6"           # 0.6.0
    - "einops>=0.8"           # 0.8.0
    - "omegaconf>=2.3"        # 2.3.0
    - "opencv-python>=4.10"   # 4.10.0.84
    - "matplotlib>=3.9"       # 3.9.1
    - "onnxruntime>=1.18"     # 1.18.1
    - "accelerate>=0.32"      # 0.32.0
    - "av>=12.2"              # 12.2.0, https://github.com/continue-revolution/sd-webui-animatediff/issues/377

  # commands run after the environment is setup
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
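
Note: the `predict: "predict.py:Predictor"` line points Cog at the Predictor class in predict.py, which is part of this upload but not reproduced here. Purely as an illustration of the interface Cog expects (not the actual implementation; the input names below are hypothetical), a skeleton looks roughly like this:

from cog import BasePredictor, Input, Path

class Predictor(BasePredictor):
    def setup(self):
        # Load model weights once per container start
        # (e.g. fetched with the pget binary installed in the `run:` step above).
        self.pipeline = None  # placeholder

    def predict(
        self,
        ref_image: Path = Input(description="Reference image"),
        motion_video: Path = Input(description="Driving pose video"),
    ) -> Path:
        # Run the MimicMotion pipeline and return the generated MP4.
        raise NotImplementedError("illustrative skeleton only")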
configs/test.yaml
ADDED
@@ -0,0 +1,20 @@
# base svd model path
base_model_path: stabilityai/stable-video-diffusion-img2vid-xt-1-1

# checkpoint path
ckpt_path: models/MimicMotion_1-1.pth

test_case:
  - ref_video_path: assets/example_data/videos/pose1.mp4
    ref_image_path: assets/example_data/images/demo1.jpg
    num_frames: 72
    resolution: 576
    frames_overlap: 6
    num_inference_steps: 25
    noise_aug_strength: 0
    guidance_scale: 2.0
    sample_stride: 2
    fps: 15
    seed: 42
constants.py
ADDED
@@ -0,0 +1,2 @@
# w/h aspect ratio
ASPECT_RATIO = 9 / 16
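
ASPECT_RATIO is the width/height ratio (9:16) that preprocess() in inference.py (below) uses to pick the generation size. A small worked example of that arithmetic for the default resolution of 576:

# Mirrors the target-size computation in preprocess() (inference.py).
ASPECT_RATIO = 9 / 16
resolution = 576

# Portrait reference image (h > w): width is pinned to `resolution`; the
# height is resolution / ASPECT_RATIO, rounded down to a multiple of 64.
h_target = int(resolution / ASPECT_RATIO // 64) * 64   # 576 * 16 / 9 = 1024
print((resolution, h_target))                          # (576, 1024) as (w, h)

# Landscape reference image (w >= h): the roles swap, giving 1024 x 576.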
environment.yaml
ADDED
@@ -0,0 +1,16 @@
name: mimicmotion
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.11
  - pytorch=2.0.1
  - torchvision=0.15.2
  - pytorch-cuda=11.7
  - pip
  - pip:
    - diffusers==0.27.0
    - transformers==4.32.1
    - decord==0.6.0
    - einops
    - omegaconf
inference.py
ADDED
@@ -0,0 +1,138 @@
import os
import argparse
import logging
import math
from omegaconf import OmegaConf
from datetime import datetime
from pathlib import Path

import numpy as np
import torch.jit
from torchvision.datasets.folder import pil_loader
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
from torchvision.transforms.functional import to_pil_image

from mimicmotion.utils.geglu_patch import patch_geglu_inplace
patch_geglu_inplace()

from constants import ASPECT_RATIO

from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
from mimicmotion.utils.loader import create_pipeline
from mimicmotion.utils.utils import save_to_mp4
from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose

logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def preprocess(video_path, image_path, resolution=576, sample_stride=2):
    """Preprocess the reference image pose and the video pose.

    Args:
        video_path (str): input pose video path
        image_path (str): reference image path
        resolution (int, optional): Defaults to 576.
        sample_stride (int, optional): Defaults to 2.
    """
    image_pixels = pil_loader(image_path)
    image_pixels = pil_to_tensor(image_pixels)  # (c, h, w)
    h, w = image_pixels.shape[-2:]
    ############################ compute target h/w according to original aspect ratio ###############################
    if h > w:
        w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
    else:
        w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
    h_w_ratio = float(h) / float(w)
    if h_w_ratio < h_target / w_target:
        h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
    else:
        h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
    image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
    image_pixels = center_crop(image_pixels, [h_target, w_target])
    image_pixels = image_pixels.permute((1, 2, 0)).numpy()
    ##################################### get image & video pose values ##############################################
    image_pose = get_image_pose(image_pixels)
    video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride)
    pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
    image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
    return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1


def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
    image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]
    generator = torch.Generator(device=device)
    generator.manual_seed(task_config.seed)
    frames = pipeline(
        image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(0),
        tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,
        height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,
        noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,
        generator=generator, min_guidance_scale=task_config.guidance_scale,
        max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device
    ).frames.cpu()
    video_frames = (frames * 255.0).to(torch.uint8)

    for vid_idx in range(video_frames.shape[0]):
        # drop the first frame, which corresponds to the reference image
        _video_frames = video_frames[vid_idx, 1:]

    return _video_frames


@torch.no_grad()
def main(args):
    if not args.no_use_float16:
        torch.set_default_dtype(torch.float16)

    infer_config = OmegaConf.load(args.inference_config)
    pipeline = create_pipeline(infer_config, device)

    for task in infer_config.test_case:
        ############################################## Pre-process data ##############################################
        pose_pixels, image_pixels = preprocess(
            task.ref_video_path, task.ref_image_path,
            resolution=task.resolution, sample_stride=task.sample_stride
        )
        ######################################### Run MimicMotion pipeline ###########################################
        _video_frames = run_pipeline(
            pipeline,
            image_pixels, pose_pixels,
            device, task
        )
        ################################### save results to output folder ############################################
        save_to_mp4(
            _video_frames,
            f"{args.output_dir}/{os.path.basename(task.ref_video_path).split('.')[0]}"
            f"_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4",
            fps=task.fps,
        )


def set_logger(log_file=None, log_level=logging.INFO):
    log_handler = logging.FileHandler(log_file, "w")
    log_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
    )
    log_handler.setLevel(log_level)
    logger.addHandler(log_handler)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_file", type=str, default=None)
    parser.add_argument("--inference_config", type=str, default="configs/test.yaml")  # ToDo
    parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output")
    parser.add_argument("--no_use_float16",
                        action="store_true",
                        help="do not use float16 (run inference in full precision)",
                        )
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    set_logger(args.log_file
               if args.log_file is not None else f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log")
    main(args)
    logger.info(f"--- Finished ---")
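
The `__main__` block above is driven by argparse, so the usual entry point is simply running the script with a config path. A programmatic equivalent is sketched below; it only rebuilds the parser's defaults and assumes the models referenced in configs/test.yaml have already been downloaded:

from argparse import Namespace
from pathlib import Path
import inference

args = Namespace(log_file=None,
                 inference_config="configs/test.yaml",
                 output_dir="outputs/",
                 no_use_float16=False)

Path(args.output_dir).mkdir(parents=True, exist_ok=True)  # normally done by __main__
inference.main(args)  # writes one timestamped MP4 per test_case entry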
mimicmotion/__init__.py
ADDED
File without changes
mimicmotion/dwpose/.gitignore
ADDED
@@ -0,0 +1 @@
*.pyc
mimicmotion/dwpose/__init__.py
ADDED
File without changes
mimicmotion/dwpose/dwpose_detector.py
ADDED
@@ -0,0 +1,72 @@
import os

import numpy as np
import torch

from .wholebody import Wholebody

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DWposeDetector:
    """
    A pose detection method for image-like data.

    Parameters:
        model_det: (str) serialized ONNX format model path,
            such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx
        model_pose: (str) serialized ONNX format model path,
            such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
        device: (str) 'cpu' or 'cuda:{device_id}'
    """
    def __init__(self, model_det, model_pose, device='cpu'):
        self.args = model_det, model_pose, device

    def release_memory(self):
        if hasattr(self, 'pose_estimation'):
            del self.pose_estimation
            import gc; gc.collect()

    def __call__(self, oriImg):
        if not hasattr(self, 'pose_estimation'):
            self.pose_estimation = Wholebody(*self.args)

        oriImg = oriImg.copy()
        H, W, C = oriImg.shape
        with torch.no_grad():
            candidate, score = self.pose_estimation(oriImg)
            nums, _, locs = candidate.shape
            candidate[..., 0] /= float(W)
            candidate[..., 1] /= float(H)
            body = candidate[:, :18].copy()
            body = body.reshape(nums * 18, locs)
            subset = score[:, :18].copy()
            for i in range(len(subset)):
                for j in range(len(subset[i])):
                    if subset[i][j] > 0.3:
                        subset[i][j] = int(18 * i + j)
                    else:
                        subset[i][j] = -1

            # un_visible = subset < 0.3
            # candidate[un_visible] = -1

            # foot = candidate[:, 18:24]

            faces = candidate[:, 24:92]

            hands = candidate[:, 92:113]
            hands = np.vstack([hands, candidate[:, 113:]])

            faces_score = score[:, 24:92]
            hands_score = np.vstack([score[:, 92:113], score[:, 113:]])

            bodies = dict(candidate=body, subset=subset, score=score[:, :18])
            pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)

        return pose

dwpose_detector = DWposeDetector(
    model_det="models/DWPose/yolox_l.onnx",
    model_pose="models/DWPose/dw-ll_ucoco_384.onnx",
    device=device)
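
For orientation, a hedged usage sketch of the module-level `dwpose_detector` singleton defined above; it assumes the two DWPose ONNX files have already been downloaded to models/DWPose/, and the input frame here is just a placeholder:

import numpy as np
from mimicmotion.dwpose.dwpose_detector import dwpose_detector

frame = np.zeros((1024, 576, 3), dtype=np.uint8)  # placeholder H x W x 3 image

pose = dwpose_detector(frame)        # first call lazily builds the Wholebody ONNX sessions
bodies = pose["bodies"]["candidate"]  # body keypoints normalized to [0, 1], typically (num_people * 18, 2)
hands, faces = pose["hands"], pose["faces"]

dwpose_detector.release_memory()     # free the ONNX sessions when finished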
mimicmotion/dwpose/onnxdet.py
ADDED
@@ -0,0 +1,145 @@
import cv2
import numpy as np


def nms(boxes, scores, nms_thr):
    """Single class NMS implemented in Numpy.

    Args:
        boxes (np.ndarray): shape=(N,4); N is number of boxes
        scores (np.ndarray): the score of bboxes
        nms_thr (float): the threshold in NMS

    Returns:
        List[int]: output bbox ids
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= nms_thr)[0]
        order = order[inds + 1]

    return keep

def multiclass_nms(boxes, scores, nms_thr, score_thr):
    """Multiclass NMS implemented in Numpy. Class-aware version.

    Args:
        boxes (np.ndarray): shape=(N,4); N is number of boxes
        scores (np.ndarray): the score of bboxes
        nms_thr (float): the threshold in NMS
        score_thr (float): the threshold of cls score

    Returns:
        np.ndarray: output bbox coordinates
    """
    final_dets = []
    num_classes = scores.shape[1]
    for cls_ind in range(num_classes):
        cls_scores = scores[:, cls_ind]
        valid_score_mask = cls_scores > score_thr
        if valid_score_mask.sum() == 0:
            continue
        else:
            valid_scores = cls_scores[valid_score_mask]
            valid_boxes = boxes[valid_score_mask]
            keep = nms(valid_boxes, valid_scores, nms_thr)
            if len(keep) > 0:
                cls_inds = np.ones((len(keep), 1)) * cls_ind
                dets = np.concatenate(
                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
                )
                final_dets.append(dets)
    if len(final_dets) == 0:
        return None
    return np.concatenate(final_dets, 0)

def demo_postprocess(outputs, img_size, p6=False):
    grids = []
    expanded_strides = []
    strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]

    hsizes = [img_size[0] // stride for stride in strides]
    wsizes = [img_size[1] // stride for stride in strides]

    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
        grids.append(grid)
        shape = grid.shape[:2]
        expanded_strides.append(np.full((*shape, 1), stride))

    grids = np.concatenate(grids, 1)
    expanded_strides = np.concatenate(expanded_strides, 1)
    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides

    return outputs

def preprocess(img, input_size, swap=(2, 0, 1)):
    if len(img.shape) == 3:
        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
    else:
        padded_img = np.ones(input_size, dtype=np.uint8) * 114

    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized_img = cv2.resize(
        img,
        (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img

    padded_img = padded_img.transpose(swap)
    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
    return padded_img, r

def inference_detector(session, oriImg):
    """run human detection
    """
    input_shape = (640, 640)
    img, ratio = preprocess(oriImg, input_shape)

    ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
    output = session.run(None, ort_inputs)
    predictions = demo_postprocess(output[0], input_shape)[0]

    boxes = predictions[:, :4]
    scores = predictions[:, 4:5] * predictions[:, 5:]

    boxes_xyxy = np.ones_like(boxes)
    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
    boxes_xyxy /= ratio
    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
    if dets is not None:
        final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
        isscore = final_scores > 0.3
        iscat = final_cls_inds == 0
        isbbox = [i and j for (i, j) in zip(isscore, iscat)]
        final_boxes = final_boxes[isbbox]
    else:
        final_boxes = np.array([])

    return final_boxes
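
inference_detector takes an already-created ONNX Runtime session plus a BGR image and returns the person boxes that pass the score threshold. A minimal sketch of calling it directly (the model path matches the one hard-coded in dwpose_detector.py above and is assumed to exist locally):

import cv2
import onnxruntime as ort
from mimicmotion.dwpose.onnxdet import inference_detector

session = ort.InferenceSession("models/DWPose/yolox_l.onnx",
                               providers=["CPUExecutionProvider"])

img = cv2.imread("assets/example_data/images/demo1.jpg")  # H x W x 3 uint8, BGR
boxes = inference_detector(session, img)  # (num_person, 4) xyxy boxes, or an empty array
print(boxes)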
mimicmotion/dwpose/onnxpose.py
ADDED
@@ -0,0 +1,375 @@
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
|
7 |
+
def preprocess(
|
8 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
9 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
10 |
+
"""Do preprocessing for RTMPose model inference.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img (np.ndarray): Input image in shape.
|
14 |
+
input_size (tuple): Input image size in shape (w, h).
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple:
|
18 |
+
- resized_img (np.ndarray): Preprocessed image.
|
19 |
+
- center (np.ndarray): Center of image.
|
20 |
+
- scale (np.ndarray): Scale of image.
|
21 |
+
"""
|
22 |
+
# get shape of image
|
23 |
+
img_shape = img.shape[:2]
|
24 |
+
out_img, out_center, out_scale = [], [], []
|
25 |
+
if len(out_bbox) == 0:
|
26 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
27 |
+
for i in range(len(out_bbox)):
|
28 |
+
x0 = out_bbox[i][0]
|
29 |
+
y0 = out_bbox[i][1]
|
30 |
+
x1 = out_bbox[i][2]
|
31 |
+
y1 = out_bbox[i][3]
|
32 |
+
bbox = np.array([x0, y0, x1, y1])
|
33 |
+
|
34 |
+
# get center and scale
|
35 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
36 |
+
|
37 |
+
# do affine transformation
|
38 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
39 |
+
|
40 |
+
# normalize image
|
41 |
+
mean = np.array([123.675, 116.28, 103.53])
|
42 |
+
std = np.array([58.395, 57.12, 57.375])
|
43 |
+
resized_img = (resized_img - mean) / std
|
44 |
+
|
45 |
+
out_img.append(resized_img)
|
46 |
+
out_center.append(center)
|
47 |
+
out_scale.append(scale)
|
48 |
+
|
49 |
+
return out_img, out_center, out_scale
|
50 |
+
|
51 |
+
|
52 |
+
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
|
53 |
+
"""Inference RTMPose model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sess (ort.InferenceSession): ONNXRuntime session.
|
57 |
+
img (np.ndarray): Input image in shape.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
outputs (np.ndarray): Output of RTMPose model.
|
61 |
+
"""
|
62 |
+
all_out = []
|
63 |
+
# build input
|
64 |
+
for i in range(len(img)):
|
65 |
+
input = [img[i].transpose(2, 0, 1)]
|
66 |
+
|
67 |
+
# build output
|
68 |
+
sess_input = {sess.get_inputs()[0].name: input}
|
69 |
+
sess_output = []
|
70 |
+
for out in sess.get_outputs():
|
71 |
+
sess_output.append(out.name)
|
72 |
+
|
73 |
+
# run model
|
74 |
+
outputs = sess.run(sess_output, sess_input)
|
75 |
+
all_out.append(outputs)
|
76 |
+
|
77 |
+
return all_out
|
78 |
+
|
79 |
+
|
80 |
+
def postprocess(outputs: List[np.ndarray],
|
81 |
+
model_input_size: Tuple[int, int],
|
82 |
+
center: Tuple[int, int],
|
83 |
+
scale: Tuple[int, int],
|
84 |
+
simcc_split_ratio: float = 2.0
|
85 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
86 |
+
"""Postprocess for RTMPose model output.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
outputs (np.ndarray): Output of RTMPose model.
|
90 |
+
model_input_size (tuple): RTMPose model Input image size.
|
91 |
+
center (tuple): Center of bbox in shape (x, y).
|
92 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
93 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
tuple:
|
97 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
98 |
+
- scores (np.ndarray): Model predict scores.
|
99 |
+
"""
|
100 |
+
all_key = []
|
101 |
+
all_score = []
|
102 |
+
for i in range(len(outputs)):
|
103 |
+
# use simcc to decode
|
104 |
+
simcc_x, simcc_y = outputs[i]
|
105 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
106 |
+
|
107 |
+
# rescale keypoints
|
108 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
109 |
+
all_key.append(keypoints[0])
|
110 |
+
all_score.append(scores[0])
|
111 |
+
|
112 |
+
return np.array(all_key), np.array(all_score)
|
113 |
+
|
114 |
+
|
115 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
116 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
117 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
118 |
+
|
119 |
+
Args:
|
120 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
121 |
+
as (left, top, right, bottom)
|
122 |
+
padding (float): BBox padding factor that will be multilied to scale.
|
123 |
+
Default: 1.0
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
tuple: A tuple containing center and scale.
|
127 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
128 |
+
(n, 2)
|
129 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
130 |
+
(n, 2)
|
131 |
+
"""
|
132 |
+
# convert single bbox from (4, ) to (1, 4)
|
133 |
+
dim = bbox.ndim
|
134 |
+
if dim == 1:
|
135 |
+
bbox = bbox[None, :]
|
136 |
+
|
137 |
+
# get bbox center and scale
|
138 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
139 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
140 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
141 |
+
|
142 |
+
if dim == 1:
|
143 |
+
center = center[0]
|
144 |
+
scale = scale[0]
|
145 |
+
|
146 |
+
return center, scale
|
147 |
+
|
148 |
+
|
149 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
150 |
+
aspect_ratio: float) -> np.ndarray:
|
151 |
+
"""Extend the scale to match the given aspect ratio.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
155 |
+
aspect_ratio (float): The ratio of ``w/h``
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
np.ndarray: The reshaped image scale in (2, )
|
159 |
+
"""
|
160 |
+
w, h = np.hsplit(bbox_scale, [1])
|
161 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
162 |
+
np.hstack([w, w / aspect_ratio]),
|
163 |
+
np.hstack([h * aspect_ratio, h]))
|
164 |
+
return bbox_scale
|
165 |
+
|
166 |
+
|
167 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
168 |
+
"""Rotate a point by an angle.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
172 |
+
angle_rad (float): rotation angle in radian
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
np.ndarray: Rotated point in shape (2, )
|
176 |
+
"""
|
177 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
178 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
179 |
+
return rot_mat @ pt
|
180 |
+
|
181 |
+
|
182 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
183 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
184 |
+
function is used to get the 3rd point, given 2D points a & b.
|
185 |
+
|
186 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
187 |
+
anticlockwise, using b as the rotation center.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
191 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
np.ndarray: The 3rd point.
|
195 |
+
"""
|
196 |
+
direction = a - b
|
197 |
+
c = b + np.r_[-direction[1], direction[0]]
|
198 |
+
return c
|
199 |
+
|
200 |
+
|
201 |
+
def get_warp_matrix(center: np.ndarray,
|
202 |
+
scale: np.ndarray,
|
203 |
+
rot: float,
|
204 |
+
output_size: Tuple[int, int],
|
205 |
+
shift: Tuple[float, float] = (0., 0.),
|
206 |
+
inv: bool = False) -> np.ndarray:
|
207 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
208 |
+
in the input image to the output size.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
212 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
213 |
+
wrt [width, height].
|
214 |
+
rot (float): Rotation angle (degree).
|
215 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
216 |
+
destination heatmaps.
|
217 |
+
shift (0-100%): Shift translation ratio wrt the width/height.
|
218 |
+
Default (0., 0.).
|
219 |
+
inv (bool): Option to inverse the affine transform direction.
|
220 |
+
(inv=False: src->dst or inv=True: dst->src)
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
np.ndarray: A 2x3 transformation matrix
|
224 |
+
"""
|
225 |
+
shift = np.array(shift)
|
226 |
+
src_w = scale[0]
|
227 |
+
dst_w = output_size[0]
|
228 |
+
dst_h = output_size[1]
|
229 |
+
|
230 |
+
# compute transformation matrix
|
231 |
+
rot_rad = np.deg2rad(rot)
|
232 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
233 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
234 |
+
|
235 |
+
# get four corners of the src rectangle in the original image
|
236 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
237 |
+
src[0, :] = center + scale * shift
|
238 |
+
src[1, :] = center + src_dir + scale * shift
|
239 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
240 |
+
|
241 |
+
# get four corners of the dst rectangle in the input image
|
242 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
243 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
244 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
245 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
246 |
+
|
247 |
+
if inv:
|
248 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
249 |
+
else:
|
250 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
251 |
+
|
252 |
+
return warp_mat
|
253 |
+
|
254 |
+
|
255 |
+
def top_down_affine(input_size: tuple, bbox_scale: np.ndarray, bbox_center: np.ndarray,
|
256 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
257 |
+
"""Get the bbox image as the model input by affine transform.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
input_size (tuple): The input size (w, h) of the model.
|
261 |
+
bbox_scale (np.ndarray): The bbox scale of the img.
|
262 |
+
bbox_center (np.ndarray): The bbox center of the img.
|
263 |
+
img (np.ndarray): The original image.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
tuple: A tuple containing the transformed image and the bbox scale.
|
267 |
+
- np.ndarray[float32]: img after affine transform.
|
268 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
269 |
+
"""
|
270 |
+
w, h = input_size
|
271 |
+
warp_size = (int(w), int(h))
|
272 |
+
|
273 |
+
# reshape bbox to fixed aspect ratio
|
274 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
275 |
+
|
276 |
+
# get the affine matrix
|
277 |
+
center = bbox_center
|
278 |
+
scale = bbox_scale
|
279 |
+
rot = 0
|
280 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
281 |
+
|
282 |
+
# do affine transform
|
283 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
284 |
+
|
285 |
+
return img, bbox_scale
|
286 |
+
|
287 |
+
|
288 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
289 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
290 |
+
"""Get maximum response location and value from simcc representations.
|
291 |
+
|
292 |
+
Note:
|
293 |
+
instance number: N
|
294 |
+
num_keypoints: K
|
295 |
+
heatmap height: H
|
296 |
+
heatmap width: W
|
297 |
+
|
298 |
+
Args:
|
299 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
300 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
tuple:
|
304 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
305 |
+
(K, 2) or (N, K, 2)
|
306 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
307 |
+
(K,) or (N, K)
|
308 |
+
"""
|
309 |
+
N, K, Wx = simcc_x.shape
|
310 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
311 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
312 |
+
|
313 |
+
# get maximum value locations
|
314 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
315 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
316 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
317 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
318 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
319 |
+
|
320 |
+
# get maximum value across x and y axis
|
321 |
+
mask = max_val_x > max_val_y
|
322 |
+
max_val_x[mask] = max_val_y[mask]
|
323 |
+
vals = max_val_x
|
324 |
+
locs[vals <= 0.] = -1
|
325 |
+
|
326 |
+
# reshape
|
327 |
+
locs = locs.reshape(N, K, 2)
|
328 |
+
vals = vals.reshape(N, K)
|
329 |
+
|
330 |
+
return locs, vals
|
331 |
+
|
332 |
+
|
333 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
334 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
335 |
+
"""Modulate simcc distribution with Gaussian.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
339 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
340 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
tuple: A tuple containing keypoints and scores.
|
344 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
345 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
346 |
+
"""
|
347 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
348 |
+
keypoints /= simcc_split_ratio
|
349 |
+
|
350 |
+
return keypoints, scores
|
351 |
+
|
352 |
+
|
353 |
+
def inference_pose(session, out_bbox, oriImg):
|
354 |
+
"""run pose detect
|
355 |
+
|
356 |
+
Args:
|
357 |
+
session (ort.InferenceSession): ONNXRuntime session.
|
358 |
+
out_bbox (np.ndarray): bbox list
|
359 |
+
oriImg (np.ndarray): Input image.
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
tuple:
|
363 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
364 |
+
- scores (np.ndarray): Model predict scores.
|
365 |
+
"""
|
366 |
+
h, w = session.get_inputs()[0].shape[2:]
|
367 |
+
model_input_size = (w, h)
|
368 |
+
# preprocess for rtm-pose model inference.
|
369 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
370 |
+
# run pose estimation for processed img
|
371 |
+
outputs = inference(session, resized_img)
|
372 |
+
# postprocess for rtm-pose model output.
|
373 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
374 |
+
|
375 |
+
return keypoints, scores
|
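The SimCC post-processing above (`get_simcc_maximum` / `decode`) reduces each keypoint to an argmax over per-axis classification bins. Below is a minimal sketch that exercises it on random inputs, assuming the repository root is on `PYTHONPATH` and the module's dependencies (NumPy, OpenCV, ONNX Runtime) are installed; every size in it is made up for illustration.

```python
import numpy as np

from mimicmotion.dwpose.onnxpose import decode

# one instance, 133 whole-body keypoints, arbitrary bin counts (illustrative only)
N, K, Wx, Wy = 1, 133, 384, 512
simcc_x = np.random.rand(N, K, Wx).astype(np.float32)
simcc_y = np.random.rand(N, K, Wy).astype(np.float32)

# decode() takes the per-axis argmax and divides by the SimCC split ratio
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio=2.0)
print(keypoints.shape, scores.shape)  # (1, 133, 2) (1, 133)
```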
mimicmotion/dwpose/preprocess.py
ADDED
@@ -0,0 +1,73 @@
1 |
+
from tqdm import tqdm
|
2 |
+
import decord
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from .util import draw_pose
|
6 |
+
from .dwpose_detector import dwpose_detector as dwprocessor
|
7 |
+
|
8 |
+
|
9 |
+
def get_video_pose(
|
10 |
+
video_path: str,
|
11 |
+
ref_image: np.ndarray,
|
12 |
+
sample_stride: int=1):
|
13 |
+
"""preprocess ref image pose and video pose
|
14 |
+
|
15 |
+
Args:
|
16 |
+
video_path (str): path to the input video
|
17 |
+
ref_image (np.ndarray): reference image
|
18 |
+
sample_stride (int, optional): Defaults to 1.
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
np.ndarray: sequence of video pose
|
22 |
+
"""
|
23 |
+
# select ref-keypoint from reference pose for pose rescale
|
24 |
+
ref_pose = dwprocessor(ref_image)
|
25 |
+
ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
|
26 |
+
ref_keypoint_id = [i for i in ref_keypoint_id \
|
27 |
+
if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0]
|
28 |
+
ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id]
|
29 |
+
|
30 |
+
height, width, _ = ref_image.shape
|
31 |
+
|
32 |
+
# read input video
|
33 |
+
vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
|
34 |
+
sample_stride *= max(1, int(vr.get_avg_fps() / 24))
|
35 |
+
|
36 |
+
frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()
|
37 |
+
detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")]
|
38 |
+
dwprocessor.release_memory()
|
39 |
+
|
40 |
+
detected_bodies = np.stack(
|
41 |
+
[p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:,
|
42 |
+
ref_keypoint_id]
|
43 |
+
# compute linear-rescale params
|
44 |
+
ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1)
|
45 |
+
fh, fw, _ = vr[0].shape
|
46 |
+
ax = ay / (fh / fw / height * width)
|
47 |
+
bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax)
|
48 |
+
a = np.array([ax, ay])
|
49 |
+
b = np.array([bx, by])
|
50 |
+
output_pose = []
|
51 |
+
# pose rescale
|
52 |
+
for detected_pose in detected_poses:
|
53 |
+
detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b
|
54 |
+
detected_pose['faces'] = detected_pose['faces'] * a + b
|
55 |
+
detected_pose['hands'] = detected_pose['hands'] * a + b
|
56 |
+
im = draw_pose(detected_pose, height, width)
|
57 |
+
output_pose.append(np.array(im))
|
58 |
+
return np.stack(output_pose)
|
59 |
+
|
60 |
+
|
61 |
+
def get_image_pose(ref_image):
|
62 |
+
"""process image pose
|
63 |
+
|
64 |
+
Args:
|
65 |
+
ref_image (np.ndarray): reference image pixel value
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
np.ndarray: pose visual image in RGB-mode
|
69 |
+
"""
|
70 |
+
height, width, _ = ref_image.shape
|
71 |
+
ref_pose = dwprocessor(ref_image)
|
72 |
+
pose_img = draw_pose(ref_pose, height, width)
|
73 |
+
return np.array(pose_img)
|
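`get_video_pose` above aligns every detected frame pose to the reference image with a per-axis linear rescale fitted by `np.polyfit`. A toy illustration of that fit, with made-up coordinates rather than real detections:

```python
import numpy as np

# Fit ref_y ≈ ay * detected_y + by with a degree-1 polynomial, then apply the
# same affine map to the detections (values are placeholders, not real poses).
ref_y      = np.array([0.20, 0.35, 0.50, 0.80])   # reference-image keypoint y
detected_y = np.array([0.25, 0.38, 0.52, 0.78])   # video-frame keypoint y

ay, by = np.polyfit(detected_y, ref_y, 1)
rescaled_y = detected_y * ay + by                 # now on the reference scale
print(round(ay, 3), round(by, 3), rescaled_y)
```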
mimicmotion/dwpose/util.py
ADDED
@@ -0,0 +1,133 @@
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
|
7 |
+
eps = 0.01
|
8 |
+
|
9 |
+
def alpha_blend_color(color, alpha):
|
10 |
+
"""blend color according to point conf
|
11 |
+
"""
|
12 |
+
return [int(c * alpha) for c in color]
|
13 |
+
|
14 |
+
def draw_bodypose(canvas, candidate, subset, score):
|
15 |
+
H, W, C = canvas.shape
|
16 |
+
candidate = np.array(candidate)
|
17 |
+
subset = np.array(subset)
|
18 |
+
|
19 |
+
stickwidth = 4
|
20 |
+
|
21 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
22 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
23 |
+
[1, 16], [16, 18], [3, 17], [6, 18]]
|
24 |
+
|
25 |
+
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
26 |
+
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
27 |
+
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
28 |
+
|
29 |
+
for i in range(17):
|
30 |
+
for n in range(len(subset)):
|
31 |
+
index = subset[n][np.array(limbSeq[i]) - 1]
|
32 |
+
conf = score[n][np.array(limbSeq[i]) - 1]
|
33 |
+
if conf[0] < 0.3 or conf[1] < 0.3:
|
34 |
+
continue
|
35 |
+
Y = candidate[index.astype(int), 0] * float(W)
|
36 |
+
X = candidate[index.astype(int), 1] * float(H)
|
37 |
+
mX = np.mean(X)
|
38 |
+
mY = np.mean(Y)
|
39 |
+
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
40 |
+
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
41 |
+
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
42 |
+
cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))
|
43 |
+
|
44 |
+
canvas = (canvas * 0.6).astype(np.uint8)
|
45 |
+
|
46 |
+
for i in range(18):
|
47 |
+
for n in range(len(subset)):
|
48 |
+
index = int(subset[n][i])
|
49 |
+
if index == -1:
|
50 |
+
continue
|
51 |
+
x, y = candidate[index][0:2]
|
52 |
+
conf = score[n][i]
|
53 |
+
x = int(x * W)
|
54 |
+
y = int(y * H)
|
55 |
+
cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)
|
56 |
+
|
57 |
+
return canvas
|
58 |
+
|
59 |
+
def draw_handpose(canvas, all_hand_peaks, all_hand_scores):
|
60 |
+
H, W, C = canvas.shape
|
61 |
+
|
62 |
+
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
|
63 |
+
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
64 |
+
|
65 |
+
for peaks, scores in zip(all_hand_peaks, all_hand_scores):
|
66 |
+
|
67 |
+
for ie, e in enumerate(edges):
|
68 |
+
x1, y1 = peaks[e[0]]
|
69 |
+
x2, y2 = peaks[e[1]]
|
70 |
+
x1 = int(x1 * W)
|
71 |
+
y1 = int(y1 * H)
|
72 |
+
x2 = int(x2 * W)
|
73 |
+
y2 = int(y2 * H)
|
74 |
+
score = int(scores[e[0]] * scores[e[1]] * 255)
|
75 |
+
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
76 |
+
cv2.line(canvas, (x1, y1), (x2, y2),
|
77 |
+
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)
|
78 |
+
|
79 |
+
for i, keypoint in enumerate(peaks):
|
80 |
+
x, y = keypoint
|
81 |
+
x = int(x * W)
|
82 |
+
y = int(y * H)
|
83 |
+
score = int(scores[i] * 255)
|
84 |
+
if x > eps and y > eps:
|
85 |
+
cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)
|
86 |
+
return canvas
|
87 |
+
|
88 |
+
def draw_facepose(canvas, all_lmks, all_scores):
|
89 |
+
H, W, C = canvas.shape
|
90 |
+
for lmks, scores in zip(all_lmks, all_scores):
|
91 |
+
for lmk, score in zip(lmks, scores):
|
92 |
+
x, y = lmk
|
93 |
+
x = int(x * W)
|
94 |
+
y = int(y * H)
|
95 |
+
conf = int(score * 255)
|
96 |
+
if x > eps and y > eps:
|
97 |
+
cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)
|
98 |
+
return canvas
|
99 |
+
|
100 |
+
def draw_pose(pose, H, W, ref_w=2160):
|
101 |
+
"""vis dwpose outputs
|
102 |
+
|
103 |
+
Args:
|
104 |
+
pose (dict): DWposeDetector outputs in dwpose_detector.py
|
105 |
+
H (int): height
|
106 |
+
W (int): width
|
107 |
+
ref_w (int, optional): reference size of the shorter image edge used for drawing. Defaults to 2160.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
np.ndarray: image pixel value in RGB mode
|
111 |
+
"""
|
112 |
+
bodies = pose['bodies']
|
113 |
+
faces = pose['faces']
|
114 |
+
hands = pose['hands']
|
115 |
+
candidate = bodies['candidate']
|
116 |
+
subset = bodies['subset']
|
117 |
+
|
118 |
+
sz = min(H, W)
|
119 |
+
sr = (ref_w / sz) if sz != ref_w else 1
|
120 |
+
|
121 |
+
########################################## create zero canvas ##################################################
|
122 |
+
canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8)
|
123 |
+
|
124 |
+
########################################### draw body pose #####################################################
|
125 |
+
canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score'])
|
126 |
+
|
127 |
+
########################################### draw hand pose #####################################################
|
128 |
+
canvas = draw_handpose(canvas, hands, pose['hands_score'])
|
129 |
+
|
130 |
+
########################################### draw face pose #####################################################
|
131 |
+
canvas = draw_facepose(canvas, faces, pose['faces_score'])
|
132 |
+
|
133 |
+
return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
|
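`draw_pose` above expects a dict with `bodies` (candidate/subset plus per-joint scores), `hands`, `hands_score`, `faces` and `faces_score`, with all coordinates normalized to [0, 1]. The sketch below is a minimal synthetic call, assuming the repository root is importable and OpenCV/matplotlib are installed; every keypoint value is a random placeholder.

```python
import numpy as np

from mimicmotion.dwpose.util import draw_pose

pose = {
    "bodies": {
        "candidate": np.random.rand(18, 2),           # 18 body keypoints (x, y)
        "subset": np.arange(18, dtype=float)[None],   # one person, every joint assigned
        "score": np.full((1, 18), 0.9),
    },
    "hands": np.random.rand(2, 21, 2),                # two hands, 21 points each
    "hands_score": np.full((2, 21), 0.9),
    "faces": np.random.rand(1, 68, 2),                # one face, 68 landmarks
    "faces_score": np.full((1, 68), 0.9),
}

canvas = draw_pose(pose, H=512, W=512)                # RGB, channel-first
print(canvas.shape)                                    # (3, 512, 512)
```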
mimicmotion/dwpose/wholebody.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
import numpy as np
|
2 |
+
import onnxruntime as ort
|
3 |
+
|
4 |
+
from .onnxdet import inference_detector
|
5 |
+
from .onnxpose import inference_pose
|
6 |
+
|
7 |
+
|
8 |
+
class Wholebody:
|
9 |
+
"""detect human pose by dwpose
|
10 |
+
"""
|
11 |
+
def __init__(self, model_det, model_pose, device="cpu"):
|
12 |
+
providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
|
13 |
+
provider_options = None if device == 'cpu' else [{'device_id': 0}]
|
14 |
+
|
15 |
+
self.session_det = ort.InferenceSession(
|
16 |
+
path_or_bytes=model_det, providers=providers, provider_options=provider_options
|
17 |
+
)
|
18 |
+
self.session_pose = ort.InferenceSession(
|
19 |
+
path_or_bytes=model_pose, providers=providers, provider_options=provider_options
|
20 |
+
)
|
21 |
+
|
22 |
+
def __call__(self, oriImg):
|
23 |
+
"""call to process dwpose-detect
|
24 |
+
|
25 |
+
Args:
|
26 |
+
oriImg (np.ndarray): input image to run detection on
|
27 |
+
|
28 |
+
"""
|
29 |
+
det_result = inference_detector(self.session_det, oriImg)
|
30 |
+
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
|
31 |
+
|
32 |
+
keypoints_info = np.concatenate(
|
33 |
+
(keypoints, scores[..., None]), axis=-1)
|
34 |
+
# compute neck joint
|
35 |
+
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
36 |
+
# neck score when visualizing pred
|
37 |
+
neck[:, 2:4] = np.logical_and(
|
38 |
+
keypoints_info[:, 5, 2:4] > 0.3,
|
39 |
+
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
40 |
+
new_keypoints_info = np.insert(
|
41 |
+
keypoints_info, 17, neck, axis=1)
|
42 |
+
mmpose_idx = [
|
43 |
+
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
|
44 |
+
]
|
45 |
+
openpose_idx = [
|
46 |
+
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
|
47 |
+
]
|
48 |
+
new_keypoints_info[:, openpose_idx] = \
|
49 |
+
new_keypoints_info[:, mmpose_idx]
|
50 |
+
keypoints_info = new_keypoints_info
|
51 |
+
|
52 |
+
keypoints, scores = keypoints_info[
|
53 |
+
..., :2], keypoints_info[..., 2]
|
54 |
+
|
55 |
+
return keypoints, scores
|
56 |
+
|
57 |
+
|
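`Wholebody.__call__` above inserts a neck joint as the mean of the two shoulders and re-orders the body joints from MMPose to OpenPose indexing. A small NumPy illustration of that step, using random stand-in data instead of real model outputs:

```python
import numpy as np

# (num_people, num_keypoints, x/y/score) with made-up values
keypoints_info = np.random.rand(1, 133, 3)

# neck = midpoint of the shoulders (MMPose indices 5 and 6); its score is 1
# only when both shoulder scores pass the 0.3 threshold
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3,
                              keypoints_info[:, 6, 2:4] > 0.3).astype(int)
new_info = np.insert(keypoints_info, 17, neck, axis=1)

mmpose_idx   = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
new_info[:, openpose_idx] = new_info[:, mmpose_idx]
print(new_info.shape)  # (1, 134, 3)
```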
mimicmotion/modules/__init__.py
ADDED
File without changes
|
mimicmotion/modules/attention.py
ADDED
@@ -0,0 +1,378 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any, Dict, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
6 |
+
from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
7 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
8 |
+
from diffusers.models.modeling_utils import ModelMixin
|
9 |
+
from diffusers.models.resnet import AlphaBlender
|
10 |
+
from diffusers.utils import BaseOutput
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class TransformerTemporalModelOutput(BaseOutput):
|
16 |
+
"""
|
17 |
+
The output of [`TransformerTemporalModel`].
|
18 |
+
|
19 |
+
Args:
|
20 |
+
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
21 |
+
The hidden states output conditioned on `encoder_hidden_states` input.
|
22 |
+
"""
|
23 |
+
|
24 |
+
sample: torch.FloatTensor
|
25 |
+
|
26 |
+
|
27 |
+
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
28 |
+
"""
|
29 |
+
A Transformer model for video-like data.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
33 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
34 |
+
in_channels (`int`, *optional*):
|
35 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
36 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
37 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
38 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
39 |
+
attention_bias (`bool`, *optional*):
|
40 |
+
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
41 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
42 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
43 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
44 |
+
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
45 |
+
activation functions.
|
46 |
+
norm_elementwise_affine (`bool`, *optional*):
|
47 |
+
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
48 |
+
double_self_attention (`bool`, *optional*):
|
49 |
+
Configure if each `TransformerBlock` should contain two self-attention layers.
|
50 |
+
positional_embeddings: (`str`, *optional*):
|
51 |
+
The type of positional embeddings to apply to the sequence input before use.
|
52 |
+
num_positional_embeddings: (`int`, *optional*):
|
53 |
+
The maximum length of the sequence over which to apply positional embeddings.
|
54 |
+
"""
|
55 |
+
|
56 |
+
@register_to_config
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
num_attention_heads: int = 16,
|
60 |
+
attention_head_dim: int = 88,
|
61 |
+
in_channels: Optional[int] = None,
|
62 |
+
out_channels: Optional[int] = None,
|
63 |
+
num_layers: int = 1,
|
64 |
+
dropout: float = 0.0,
|
65 |
+
norm_num_groups: int = 32,
|
66 |
+
cross_attention_dim: Optional[int] = None,
|
67 |
+
attention_bias: bool = False,
|
68 |
+
sample_size: Optional[int] = None,
|
69 |
+
activation_fn: str = "geglu",
|
70 |
+
norm_elementwise_affine: bool = True,
|
71 |
+
double_self_attention: bool = True,
|
72 |
+
positional_embeddings: Optional[str] = None,
|
73 |
+
num_positional_embeddings: Optional[int] = None,
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
self.num_attention_heads = num_attention_heads
|
77 |
+
self.attention_head_dim = attention_head_dim
|
78 |
+
inner_dim = num_attention_heads * attention_head_dim
|
79 |
+
|
80 |
+
self.in_channels = in_channels
|
81 |
+
|
82 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
83 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
84 |
+
|
85 |
+
# 3. Define transformers blocks
|
86 |
+
self.transformer_blocks = nn.ModuleList(
|
87 |
+
[
|
88 |
+
BasicTransformerBlock(
|
89 |
+
inner_dim,
|
90 |
+
num_attention_heads,
|
91 |
+
attention_head_dim,
|
92 |
+
dropout=dropout,
|
93 |
+
cross_attention_dim=cross_attention_dim,
|
94 |
+
activation_fn=activation_fn,
|
95 |
+
attention_bias=attention_bias,
|
96 |
+
double_self_attention=double_self_attention,
|
97 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
98 |
+
positional_embeddings=positional_embeddings,
|
99 |
+
num_positional_embeddings=num_positional_embeddings,
|
100 |
+
)
|
101 |
+
for d in range(num_layers)
|
102 |
+
]
|
103 |
+
)
|
104 |
+
|
105 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
106 |
+
|
107 |
+
def forward(
|
108 |
+
self,
|
109 |
+
hidden_states: torch.FloatTensor,
|
110 |
+
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
111 |
+
timestep: Optional[torch.LongTensor] = None,
|
112 |
+
class_labels: torch.LongTensor = None,
|
113 |
+
num_frames: int = 1,
|
114 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
115 |
+
return_dict: bool = True,
|
116 |
+
) -> TransformerTemporalModelOutput:
|
117 |
+
"""
|
118 |
+
The [`TransformerTemporal`] forward method.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
|
122 |
+
`torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): Input hidden_states.
|
123 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
124 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
125 |
+
self-attention.
|
126 |
+
timestep ( `torch.LongTensor`, *optional*):
|
127 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
128 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
129 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
130 |
+
`AdaLayerZeroNorm`.
|
131 |
+
num_frames (`int`, *optional*, defaults to 1):
|
132 |
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
133 |
+
cross_attention_kwargs (`dict`, *optional*):
|
134 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
135 |
+
`self.processor` in [diffusers.models.attention_processor](
|
136 |
+
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
137 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
138 |
+
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
139 |
+
tuple.
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
143 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
144 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
145 |
+
"""
|
146 |
+
# 1. Input
|
147 |
+
batch_frames, channel, height, width = hidden_states.shape
|
148 |
+
batch_size = batch_frames // num_frames
|
149 |
+
|
150 |
+
residual = hidden_states
|
151 |
+
|
152 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
153 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
154 |
+
|
155 |
+
hidden_states = self.norm(hidden_states)
|
156 |
+
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
157 |
+
|
158 |
+
hidden_states = self.proj_in(hidden_states)
|
159 |
+
|
160 |
+
# 2. Blocks
|
161 |
+
for block in self.transformer_blocks:
|
162 |
+
hidden_states = block(
|
163 |
+
hidden_states,
|
164 |
+
encoder_hidden_states=encoder_hidden_states,
|
165 |
+
timestep=timestep,
|
166 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
167 |
+
class_labels=class_labels,
|
168 |
+
)
|
169 |
+
|
170 |
+
# 3. Output
|
171 |
+
hidden_states = self.proj_out(hidden_states)
|
172 |
+
hidden_states = (
|
173 |
+
hidden_states[None, None, :]
|
174 |
+
.reshape(batch_size, height, width, num_frames, channel)
|
175 |
+
.permute(0, 3, 4, 1, 2)
|
176 |
+
.contiguous()
|
177 |
+
)
|
178 |
+
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
179 |
+
|
180 |
+
output = hidden_states + residual
|
181 |
+
|
182 |
+
if not return_dict:
|
183 |
+
return (output,)
|
184 |
+
|
185 |
+
return TransformerTemporalModelOutput(sample=output)
|
186 |
+
|
187 |
+
|
188 |
+
class TransformerSpatioTemporalModel(nn.Module):
|
189 |
+
"""
|
190 |
+
A Transformer model for video-like data.
|
191 |
+
|
192 |
+
Parameters:
|
193 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
194 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
195 |
+
in_channels (`int`, *optional*):
|
196 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
197 |
+
out_channels (`int`, *optional*):
|
198 |
+
The number of channels in the output (specify if the input is **continuous**).
|
199 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
200 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
201 |
+
"""
|
202 |
+
|
203 |
+
def __init__(
|
204 |
+
self,
|
205 |
+
num_attention_heads: int = 16,
|
206 |
+
attention_head_dim: int = 88,
|
207 |
+
in_channels: int = 320,
|
208 |
+
out_channels: Optional[int] = None,
|
209 |
+
num_layers: int = 1,
|
210 |
+
cross_attention_dim: Optional[int] = None,
|
211 |
+
):
|
212 |
+
super().__init__()
|
213 |
+
self.num_attention_heads = num_attention_heads
|
214 |
+
self.attention_head_dim = attention_head_dim
|
215 |
+
|
216 |
+
inner_dim = num_attention_heads * attention_head_dim
|
217 |
+
self.inner_dim = inner_dim
|
218 |
+
|
219 |
+
# 2. Define input layers
|
220 |
+
self.in_channels = in_channels
|
221 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
222 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
223 |
+
|
224 |
+
# 3. Define transformers blocks
|
225 |
+
self.transformer_blocks = nn.ModuleList(
|
226 |
+
[
|
227 |
+
BasicTransformerBlock(
|
228 |
+
inner_dim,
|
229 |
+
num_attention_heads,
|
230 |
+
attention_head_dim,
|
231 |
+
cross_attention_dim=cross_attention_dim,
|
232 |
+
)
|
233 |
+
for d in range(num_layers)
|
234 |
+
]
|
235 |
+
)
|
236 |
+
|
237 |
+
time_mix_inner_dim = inner_dim
|
238 |
+
self.temporal_transformer_blocks = nn.ModuleList(
|
239 |
+
[
|
240 |
+
TemporalBasicTransformerBlock(
|
241 |
+
inner_dim,
|
242 |
+
time_mix_inner_dim,
|
243 |
+
num_attention_heads,
|
244 |
+
attention_head_dim,
|
245 |
+
cross_attention_dim=cross_attention_dim,
|
246 |
+
)
|
247 |
+
for _ in range(num_layers)
|
248 |
+
]
|
249 |
+
)
|
250 |
+
|
251 |
+
time_embed_dim = in_channels * 4
|
252 |
+
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
253 |
+
self.time_proj = Timesteps(in_channels, True, 0)
|
254 |
+
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
255 |
+
|
256 |
+
# 4. Define output layers
|
257 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
258 |
+
# TODO: should use out_channels for continuous projections
|
259 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
260 |
+
|
261 |
+
self.gradient_checkpointing = False
|
262 |
+
|
263 |
+
def forward(
|
264 |
+
self,
|
265 |
+
hidden_states: torch.Tensor,
|
266 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
267 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
268 |
+
return_dict: bool = True,
|
269 |
+
):
|
270 |
+
"""
|
271 |
+
Args:
|
272 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
273 |
+
Input hidden_states.
|
274 |
+
num_frames (`int`):
|
275 |
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
276 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
277 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
278 |
+
self-attention.
|
279 |
+
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
280 |
+
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
281 |
+
images, 0 indicates that the input contains video frames.
|
282 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
283 |
+
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`]
|
284 |
+
instead of a plain tuple.
|
285 |
+
|
286 |
+
Returns:
|
287 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
288 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
289 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
290 |
+
"""
|
291 |
+
# 1. Input
|
292 |
+
batch_frames, _, height, width = hidden_states.shape
|
293 |
+
num_frames = image_only_indicator.shape[-1]
|
294 |
+
batch_size = batch_frames // num_frames
|
295 |
+
|
296 |
+
time_context = encoder_hidden_states
|
297 |
+
time_context_first_timestep = time_context[None, :].reshape(
|
298 |
+
batch_size, num_frames, -1, time_context.shape[-1]
|
299 |
+
)[:, 0]
|
300 |
+
time_context = time_context_first_timestep[None, :].broadcast_to(
|
301 |
+
height * width, batch_size, 1, time_context.shape[-1]
|
302 |
+
)
|
303 |
+
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
304 |
+
|
305 |
+
residual = hidden_states
|
306 |
+
|
307 |
+
hidden_states = self.norm(hidden_states)
|
308 |
+
inner_dim = hidden_states.shape[1]
|
309 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
310 |
+
hidden_states = torch.utils.checkpoint.checkpoint(self.proj_in, hidden_states)
|
311 |
+
|
312 |
+
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
313 |
+
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
314 |
+
num_frames_emb = num_frames_emb.reshape(-1)
|
315 |
+
t_emb = self.time_proj(num_frames_emb)
|
316 |
+
|
317 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
318 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
319 |
+
# there might be better ways to encapsulate this.
|
320 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
321 |
+
|
322 |
+
emb = self.time_pos_embed(t_emb)
|
323 |
+
emb = emb[:, None, :]
|
324 |
+
|
325 |
+
# 2. Blocks
|
326 |
+
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
327 |
+
if self.gradient_checkpointing:
|
328 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
329 |
+
block,
|
330 |
+
hidden_states,
|
331 |
+
None,
|
332 |
+
encoder_hidden_states,
|
333 |
+
None,
|
334 |
+
use_reentrant=False,
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
hidden_states = block(
|
338 |
+
hidden_states,
|
339 |
+
encoder_hidden_states=encoder_hidden_states,
|
340 |
+
)
|
341 |
+
|
342 |
+
hidden_states_mix = hidden_states
|
343 |
+
hidden_states_mix = hidden_states_mix + emb
|
344 |
+
|
345 |
+
if self.gradient_checkpointing:
|
346 |
+
hidden_states_mix = torch.utils.checkpoint.checkpoint(
|
347 |
+
temporal_block,
|
348 |
+
hidden_states_mix,
|
349 |
+
num_frames,
|
350 |
+
time_context,
|
351 |
+
)
|
352 |
+
hidden_states = self.time_mixer(
|
353 |
+
x_spatial=hidden_states,
|
354 |
+
x_temporal=hidden_states_mix,
|
355 |
+
image_only_indicator=image_only_indicator,
|
356 |
+
)
|
357 |
+
else:
|
358 |
+
hidden_states_mix = temporal_block(
|
359 |
+
hidden_states_mix,
|
360 |
+
num_frames=num_frames,
|
361 |
+
encoder_hidden_states=time_context,
|
362 |
+
)
|
363 |
+
hidden_states = self.time_mixer(
|
364 |
+
x_spatial=hidden_states,
|
365 |
+
x_temporal=hidden_states_mix,
|
366 |
+
image_only_indicator=image_only_indicator,
|
367 |
+
)
|
368 |
+
|
369 |
+
# 3. Output
|
370 |
+
hidden_states = torch.utils.checkpoint.checkpoint(self.proj_out, hidden_states)
|
371 |
+
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
372 |
+
|
373 |
+
output = hidden_states + residual
|
374 |
+
|
375 |
+
if not return_dict:
|
376 |
+
return (output,)
|
377 |
+
|
378 |
+
return TransformerTemporalModelOutput(sample=output)
|
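`TransformerTemporalModel.forward` above folds the spatial positions into the batch dimension so that attention runs along the frame axis. A short torch sketch of that reshape round trip, with illustrative sizes only:

```python
import torch

B, F, C, H, W = 2, 8, 320, 16, 16
hidden = torch.randn(B * F, C, H, W)              # (batch * frames, channels, H, W)

# pixels become the batch, frames become the sequence
x = hidden.reshape(B, F, C, H, W).permute(0, 2, 1, 3, 4)       # (B, C, F, H, W)
x = x.permute(0, 3, 4, 2, 1).reshape(B * H * W, F, C)          # (B*H*W, F, C)
# ... temporal attention blocks operate on x here ...

# invert the reshuffle back to the UNet layout
x = x.reshape(B, H, W, F, C).permute(0, 3, 4, 1, 2).contiguous()
x = x.reshape(B * F, C, H, W)
assert x.shape == hidden.shape
```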
mimicmotion/modules/pose_net.py
ADDED
@@ -0,0 +1,80 @@
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import einops
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.init as init
|
8 |
+
|
9 |
+
|
10 |
+
class PoseNet(nn.Module):
|
11 |
+
"""a tiny conv network for introducing pose sequence as the condition
|
12 |
+
"""
|
13 |
+
def __init__(self, noise_latent_channels=320, *args, **kwargs):
|
14 |
+
super().__init__(*args, **kwargs)
|
15 |
+
# multiple convolution layers
|
16 |
+
self.conv_layers = nn.Sequential(
|
17 |
+
nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
|
18 |
+
nn.SiLU(),
|
19 |
+
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
|
20 |
+
nn.SiLU(),
|
21 |
+
|
22 |
+
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
|
23 |
+
nn.SiLU(),
|
24 |
+
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
|
25 |
+
nn.SiLU(),
|
26 |
+
|
27 |
+
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
|
28 |
+
nn.SiLU(),
|
29 |
+
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
|
30 |
+
nn.SiLU(),
|
31 |
+
|
32 |
+
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
|
33 |
+
nn.SiLU(),
|
34 |
+
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
|
35 |
+
nn.SiLU()
|
36 |
+
)
|
37 |
+
|
38 |
+
# Final projection layer
|
39 |
+
self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
|
40 |
+
|
41 |
+
# Initialize layers
|
42 |
+
self._initialize_weights()
|
43 |
+
|
44 |
+
self.scale = nn.Parameter(torch.ones(1) * 2)
|
45 |
+
|
46 |
+
def _initialize_weights(self):
|
47 |
+
"""Initialize weights with He. initialization and zero out the biases
|
48 |
+
"""
|
49 |
+
for m in self.conv_layers:
|
50 |
+
if isinstance(m, nn.Conv2d):
|
51 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
|
52 |
+
init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
|
53 |
+
if m.bias is not None:
|
54 |
+
init.zeros_(m.bias)
|
55 |
+
init.zeros_(self.final_proj.weight)
|
56 |
+
if self.final_proj.bias is not None:
|
57 |
+
init.zeros_(self.final_proj.bias)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
if x.ndim == 5:
|
61 |
+
x = einops.rearrange(x, "b f c h w -> (b f) c h w")
|
62 |
+
x = self.conv_layers(x)
|
63 |
+
x = self.final_proj(x)
|
64 |
+
|
65 |
+
return x * self.scale
|
66 |
+
|
67 |
+
@classmethod
|
68 |
+
def from_pretrained(cls, pretrained_model_path):
|
69 |
+
"""load pretrained pose-net weights
|
70 |
+
"""
|
71 |
+
if not Path(pretrained_model_path).exists():
|
72 |
+
print(f"There is no model file in {pretrained_model_path}")
|
73 |
+
print(f"loaded PoseNet's pretrained weights from {pretrained_model_path}.")
|
74 |
+
|
75 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
76 |
+
model = PoseNet(noise_latent_channels=320)
|
77 |
+
|
78 |
+
model.load_state_dict(state_dict, strict=True)
|
79 |
+
|
80 |
+
return model
|
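`PoseNet` above downsamples the pose images by 8x through three stride-2 convolutions and projects them to the UNet's 320 latent channels (with `final_proj` zero-initialized, so a freshly constructed network outputs zeros). A quick shape check, assuming the repository root is importable and torch/einops are installed; the tensor sizes are made up:

```python
import torch

from mimicmotion.modules.pose_net import PoseNet

net = PoseNet(noise_latent_channels=320)
pose_frames = torch.randn(2, 4, 3, 256, 256)   # (batch, frames, rgb, H, W)
latents = net(pose_frames)                     # frames are folded into the batch
print(latents.shape)                            # torch.Size([8, 320, 32, 32])
```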
mimicmotion/modules/unet.py
ADDED
@@ -0,0 +1,507 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
7 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
8 |
+
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
9 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
10 |
+
from diffusers.models.modeling_utils import ModelMixin
|
11 |
+
from diffusers.utils import BaseOutput, logging
|
12 |
+
|
13 |
+
from diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class UNetSpatioTemporalConditionOutput(BaseOutput):
|
20 |
+
"""
|
21 |
+
The output of [`UNetSpatioTemporalConditionModel`].
|
22 |
+
|
23 |
+
Args:
|
24 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
|
25 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
26 |
+
"""
|
27 |
+
|
28 |
+
sample: torch.FloatTensor = None
|
29 |
+
|
30 |
+
|
31 |
+
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
32 |
+
r"""
|
33 |
+
A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state,
|
34 |
+
and a timestep and returns a sample shaped output.
|
35 |
+
|
36 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
37 |
+
for all models (such as downloading or saving).
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
41 |
+
Height and width of input/output sample.
|
42 |
+
in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
|
43 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
44 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal",
|
45 |
+
"CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
|
46 |
+
The tuple of downsample blocks to use.
|
47 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal",
|
48 |
+
"CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
|
49 |
+
The tuple of upsample blocks to use.
|
50 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
51 |
+
The tuple of output channels for each block.
|
52 |
+
addition_time_embed_dim: (`int`, defaults to 256):
|
53 |
+
Dimension to encode the additional time ids.
|
54 |
+
projection_class_embeddings_input_dim (`int`, defaults to 768):
|
55 |
+
The dimension of the projection of encoded `added_time_ids`.
|
56 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
57 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
58 |
+
The dimension of the cross attention features.
|
59 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
60 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
61 |
+
[`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
|
62 |
+
[`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
|
63 |
+
[`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
|
64 |
+
num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
|
65 |
+
The number of attention heads.
|
66 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
67 |
+
"""
|
68 |
+
|
69 |
+
_supports_gradient_checkpointing = True
|
70 |
+
|
71 |
+
@register_to_config
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
sample_size: Optional[int] = None,
|
75 |
+
in_channels: int = 8,
|
76 |
+
out_channels: int = 4,
|
77 |
+
down_block_types: Tuple[str] = (
|
78 |
+
"CrossAttnDownBlockSpatioTemporal",
|
79 |
+
"CrossAttnDownBlockSpatioTemporal",
|
80 |
+
"CrossAttnDownBlockSpatioTemporal",
|
81 |
+
"DownBlockSpatioTemporal",
|
82 |
+
),
|
83 |
+
up_block_types: Tuple[str] = (
|
84 |
+
"UpBlockSpatioTemporal",
|
85 |
+
"CrossAttnUpBlockSpatioTemporal",
|
86 |
+
"CrossAttnUpBlockSpatioTemporal",
|
87 |
+
"CrossAttnUpBlockSpatioTemporal",
|
88 |
+
),
|
89 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
90 |
+
addition_time_embed_dim: int = 256,
|
91 |
+
projection_class_embeddings_input_dim: int = 768,
|
92 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
93 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
94 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
95 |
+
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
|
96 |
+
num_frames: int = 25,
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.sample_size = sample_size
|
101 |
+
|
102 |
+
# Check inputs
|
103 |
+
if len(down_block_types) != len(up_block_types):
|
104 |
+
raise ValueError(
|
105 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. " \
|
106 |
+
f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
107 |
+
)
|
108 |
+
|
109 |
+
if len(block_out_channels) != len(down_block_types):
|
110 |
+
raise ValueError(
|
111 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. " \
|
112 |
+
f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
113 |
+
)
|
114 |
+
|
115 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
116 |
+
raise ValueError(
|
117 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. " \
|
118 |
+
f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
119 |
+
)
|
120 |
+
|
121 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
122 |
+
raise ValueError(
|
123 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. " \
|
124 |
+
f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
125 |
+
)
|
126 |
+
|
127 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
128 |
+
raise ValueError(
|
129 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. " \
|
130 |
+
f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
131 |
+
)
|
132 |
+
|
133 |
+
# input
|
134 |
+
self.conv_in = nn.Conv2d(
|
135 |
+
in_channels,
|
136 |
+
block_out_channels[0],
|
137 |
+
kernel_size=3,
|
138 |
+
padding=1,
|
139 |
+
)
|
140 |
+
|
141 |
+
# time
|
142 |
+
time_embed_dim = block_out_channels[0] * 4
|
143 |
+
|
144 |
+
self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
|
145 |
+
timestep_input_dim = block_out_channels[0]
|
146 |
+
|
147 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
148 |
+
|
149 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
|
150 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
151 |
+
|
152 |
+
self.down_blocks = nn.ModuleList([])
|
153 |
+
self.up_blocks = nn.ModuleList([])
|
154 |
+
|
155 |
+
if isinstance(num_attention_heads, int):
|
156 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
157 |
+
|
158 |
+
if isinstance(cross_attention_dim, int):
|
159 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
160 |
+
|
161 |
+
if isinstance(layers_per_block, int):
|
162 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
163 |
+
|
164 |
+
if isinstance(transformer_layers_per_block, int):
|
165 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
166 |
+
|
167 |
+
blocks_time_embed_dim = time_embed_dim
|
168 |
+
|
169 |
+
# down
|
170 |
+
output_channel = block_out_channels[0]
|
171 |
+
for i, down_block_type in enumerate(down_block_types):
|
172 |
+
input_channel = output_channel
|
173 |
+
output_channel = block_out_channels[i]
|
174 |
+
is_final_block = i == len(block_out_channels) - 1
|
175 |
+
|
176 |
+
down_block = get_down_block(
|
177 |
+
down_block_type,
|
178 |
+
num_layers=layers_per_block[i],
|
179 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
180 |
+
in_channels=input_channel,
|
181 |
+
out_channels=output_channel,
|
182 |
+
temb_channels=blocks_time_embed_dim,
|
183 |
+
add_downsample=not is_final_block,
|
184 |
+
resnet_eps=1e-5,
|
185 |
+
cross_attention_dim=cross_attention_dim[i],
|
186 |
+
num_attention_heads=num_attention_heads[i],
|
187 |
+
resnet_act_fn="silu",
|
188 |
+
)
|
189 |
+
self.down_blocks.append(down_block)
|
190 |
+
|
191 |
+
# mid
|
192 |
+
self.mid_block = UNetMidBlockSpatioTemporal(
|
193 |
+
block_out_channels[-1],
|
194 |
+
temb_channels=blocks_time_embed_dim,
|
195 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
196 |
+
cross_attention_dim=cross_attention_dim[-1],
|
197 |
+
num_attention_heads=num_attention_heads[-1],
|
198 |
+
)
|
199 |
+
|
200 |
+
# count how many layers upsample the images
|
201 |
+
self.num_upsamplers = 0
|
202 |
+
|
203 |
+
# up
|
204 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
205 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
206 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
207 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
208 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
209 |
+
|
210 |
+
output_channel = reversed_block_out_channels[0]
|
211 |
+
for i, up_block_type in enumerate(up_block_types):
|
212 |
+
is_final_block = i == len(block_out_channels) - 1
|
213 |
+
|
214 |
+
prev_output_channel = output_channel
|
215 |
+
output_channel = reversed_block_out_channels[i]
|
216 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
217 |
+
|
218 |
+
# add upsample block for all BUT final layer
|
219 |
+
if not is_final_block:
|
220 |
+
add_upsample = True
|
221 |
+
self.num_upsamplers += 1
|
222 |
+
else:
|
223 |
+
add_upsample = False
|
224 |
+
|
225 |
+
up_block = get_up_block(
|
226 |
+
up_block_type,
|
227 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
228 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
229 |
+
in_channels=input_channel,
|
230 |
+
out_channels=output_channel,
|
231 |
+
prev_output_channel=prev_output_channel,
|
232 |
+
temb_channels=blocks_time_embed_dim,
|
233 |
+
add_upsample=add_upsample,
|
234 |
+
resnet_eps=1e-5,
|
235 |
+
resolution_idx=i,
|
236 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
237 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
238 |
+
resnet_act_fn="silu",
|
239 |
+
)
|
240 |
+
self.up_blocks.append(up_block)
|
241 |
+
prev_output_channel = output_channel
|
242 |
+
|
243 |
+
# out
|
244 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
|
245 |
+
self.conv_act = nn.SiLU()
|
246 |
+
|
247 |
+
self.conv_out = nn.Conv2d(
|
248 |
+
block_out_channels[0],
|
249 |
+
out_channels,
|
250 |
+
kernel_size=3,
|
251 |
+
padding=1,
|
252 |
+
)
|
253 |
+
|
254 |
+
@property
|
255 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
256 |
+
r"""
|
257 |
+
Returns:
|
258 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
259 |
+
indexed by its weight name.
|
260 |
+
"""
|
261 |
+
# set recursively
|
262 |
+
processors = {}
|
263 |
+
|
264 |
+
def fn_recursive_add_processors(
|
265 |
+
name: str,
|
266 |
+
module: torch.nn.Module,
|
267 |
+
processors: Dict[str, AttentionProcessor],
|
268 |
+
):
|
269 |
+
if hasattr(module, "get_processor"):
|
270 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
271 |
+
|
272 |
+
for sub_name, child in module.named_children():
|
273 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
274 |
+
|
275 |
+
return processors
|
276 |
+
|
277 |
+
for name, module in self.named_children():
|
278 |
+
fn_recursive_add_processors(name, module, processors)
|
279 |
+
|
280 |
+
return processors
|
281 |
+
|
282 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
283 |
+
r"""
|
284 |
+
Sets the attention processor to use to compute attention.
|
285 |
+
|
286 |
+
Parameters:
|
287 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
288 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
289 |
+
for **all** `Attention` layers.
|
290 |
+
|
291 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
292 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
293 |
+
|
294 |
+
"""
|
295 |
+
count = len(self.attn_processors.keys())
|
296 |
+
|
297 |
+
if isinstance(processor, dict) and len(processor) != count:
|
298 |
+
raise ValueError(
|
299 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
300 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
301 |
+
)
|
302 |
+
|
303 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
304 |
+
if hasattr(module, "set_processor"):
|
305 |
+
if not isinstance(processor, dict):
|
306 |
+
module.set_processor(processor)
|
307 |
+
else:
|
308 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
309 |
+
|
310 |
+
for sub_name, child in module.named_children():
|
311 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
312 |
+
|
313 |
+
for name, module in self.named_children():
|
314 |
+
fn_recursive_attn_processor(name, module, processor)
|
315 |
+
|
316 |
+
def set_default_attn_processor(self):
|
317 |
+
"""
|
318 |
+
Disables custom attention processors and sets the default attention implementation.
|
319 |
+
"""
|
320 |
+
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
321 |
+
processor = AttnProcessor()
|
322 |
+
else:
|
323 |
+
raise ValueError(
|
324 |
+
f"Cannot call `set_default_attn_processor` " \
|
325 |
+
f"when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
326 |
+
)
|
327 |
+
|
328 |
+
self.set_attn_processor(processor)
|
329 |
+
|
330 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
331 |
+
if hasattr(module, "gradient_checkpointing"):
|
332 |
+
module.gradient_checkpointing = value
|
333 |
+
|
334 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
335 |
+
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
336 |
+
"""
|
337 |
+
Sets the attention processor to use [feed forward
|
338 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
339 |
+
|
340 |
+
Parameters:
|
341 |
+
chunk_size (`int`, *optional*):
|
342 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
343 |
+
over each tensor of dim=`dim`.
|
344 |
+
dim (`int`, *optional*, defaults to `0`):
|
345 |
+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
346 |
+
or dim=1 (sequence length).
|
347 |
+
"""
|
348 |
+
if dim not in [0, 1]:
|
349 |
+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
350 |
+
|
351 |
+
# By default chunk size is 1
|
352 |
+
chunk_size = chunk_size or 1
|
353 |
+
|
354 |
+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
355 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
356 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
357 |
+
|
358 |
+
for child in module.children():
|
359 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
360 |
+
|
361 |
+
for module in self.children():
|
362 |
+
fn_recursive_feed_forward(module, chunk_size, dim)
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
sample: torch.FloatTensor,
|
367 |
+
timestep: Union[torch.Tensor, float, int],
|
368 |
+
encoder_hidden_states: torch.Tensor,
|
369 |
+
added_time_ids: torch.Tensor,
|
370 |
+
pose_latents: torch.Tensor = None,
|
371 |
+
image_only_indicator: bool = False,
|
372 |
+
return_dict: bool = True,
|
373 |
+
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
374 |
+
r"""
|
375 |
+
The [`UNetSpatioTemporalConditionModel`] forward method.
|
376 |
+
|
377 |
+
Args:
|
378 |
+
sample (`torch.FloatTensor`):
|
379 |
+
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
|
380 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
381 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
382 |
+
The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
|
383 |
+
added_time_ids: (`torch.FloatTensor`):
|
384 |
+
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
|
385 |
+
embeddings and added to the time embeddings.
|
386 |
+
pose_latents: (`torch.FloatTensor`):
|
387 |
+
The additional latents for pose sequences.
|
388 |
+
image_only_indicator (`bool`, *optional*, defaults to `False`):
|
389 |
+
Whether to treat the inputs as a batch of images instead of a video.
|
390 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
391 |
+
Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`]
|
392 |
+
instead of a plain tuple.
|
393 |
+
Returns:
|
394 |
+
[`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
|
395 |
+
If `return_dict` is True,
|
396 |
+
an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned,
|
397 |
+
otherwise a `tuple` is returned where the first element is the sample tensor.
|
398 |
+
"""
|
399 |
+
# 1. time
|
400 |
+
timesteps = timestep
|
401 |
+
if not torch.is_tensor(timesteps):
|
402 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
403 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
404 |
+
is_mps = sample.device.type == "mps"
|
405 |
+
if isinstance(timestep, float):
|
406 |
+
dtype = torch.float32 if is_mps else torch.float64
|
407 |
+
else:
|
408 |
+
dtype = torch.int32 if is_mps else torch.int64
|
409 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
410 |
+
elif len(timesteps.shape) == 0:
|
411 |
+
timesteps = timesteps[None].to(sample.device)
|
412 |
+
|
413 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
414 |
+
batch_size, num_frames = sample.shape[:2]
|
415 |
+
timesteps = timesteps.expand(batch_size)
|
416 |
+
|
417 |
+
t_emb = self.time_proj(timesteps)
|
418 |
+
|
419 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
420 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
421 |
+
# there might be better ways to encapsulate this.
|
422 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
423 |
+
|
424 |
+
emb = self.time_embedding(t_emb)
|
425 |
+
|
426 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
427 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
428 |
+
time_embeds = time_embeds.to(emb.dtype)
|
429 |
+
aug_emb = self.add_embedding(time_embeds)
|
430 |
+
emb = emb + aug_emb
|
431 |
+
|
432 |
+
# Flatten the batch and frames dimensions
|
433 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
434 |
+
sample = sample.flatten(0, 1)
|
435 |
+
# Repeat the embeddings num_video_frames times
|
436 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
437 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
438 |
+
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
439 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
|
440 |
+
|
441 |
+
# 2. pre-process
|
442 |
+
sample = self.conv_in(sample)
|
443 |
+
if pose_latents is not None:
|
444 |
+
sample = sample + pose_latents
|
445 |
+
|
446 |
+
image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \
|
447 |
+
if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
|
448 |
+
|
449 |
+
down_block_res_samples = (sample,)
|
450 |
+
for downsample_block in self.down_blocks:
|
451 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
452 |
+
sample, res_samples = downsample_block(
|
453 |
+
hidden_states=sample,
|
454 |
+
temb=emb,
|
455 |
+
encoder_hidden_states=encoder_hidden_states,
|
456 |
+
image_only_indicator=image_only_indicator,
|
457 |
+
)
|
458 |
+
else:
|
459 |
+
sample, res_samples = downsample_block(
|
460 |
+
hidden_states=sample,
|
461 |
+
temb=emb,
|
462 |
+
image_only_indicator=image_only_indicator,
|
463 |
+
)
|
464 |
+
|
465 |
+
down_block_res_samples += res_samples
|
466 |
+
|
467 |
+
# 4. mid
|
468 |
+
sample = self.mid_block(
|
469 |
+
hidden_states=sample,
|
470 |
+
temb=emb,
|
471 |
+
encoder_hidden_states=encoder_hidden_states,
|
472 |
+
image_only_indicator=image_only_indicator,
|
473 |
+
)
|
474 |
+
|
475 |
+
# 5. up
|
476 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
477 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
478 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
479 |
+
|
480 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
481 |
+
sample = upsample_block(
|
482 |
+
hidden_states=sample,
|
483 |
+
temb=emb,
|
484 |
+
res_hidden_states_tuple=res_samples,
|
485 |
+
encoder_hidden_states=encoder_hidden_states,
|
486 |
+
image_only_indicator=image_only_indicator,
|
487 |
+
)
|
488 |
+
else:
|
489 |
+
sample = upsample_block(
|
490 |
+
hidden_states=sample,
|
491 |
+
temb=emb,
|
492 |
+
res_hidden_states_tuple=res_samples,
|
493 |
+
image_only_indicator=image_only_indicator,
|
494 |
+
)
|
495 |
+
|
496 |
+
# 6. post-process
|
497 |
+
sample = self.conv_norm_out(sample)
|
498 |
+
sample = self.conv_act(sample)
|
499 |
+
sample = self.conv_out(sample)
|
500 |
+
|
501 |
+
# 7. Reshape back to original shape
|
502 |
+
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
503 |
+
|
504 |
+
if not return_dict:
|
505 |
+
return (sample,)
|
506 |
+
|
507 |
+
return UNetSpatioTemporalConditionOutput(sample=sample)
|
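To make the shape bookkeeping in `forward` above easier to follow, here is a minimal standalone sketch of just the flatten/repeat steps, using dummy tensors; the embedding widths (320 for the time embedding, 1024 for the image embedding) are illustrative assumptions, not values read from the model config.

```python
import torch

# Illustrative shapes matching the forward() docstring: (batch, frames, channels, height, width)
batch_size, num_frames, channels, height, width = 2, 16, 8, 72, 128
sample = torch.randn(batch_size, num_frames, channels, height, width)
emb = torch.randn(batch_size, 320)                        # one time embedding per video (assumed width)
encoder_hidden_states = torch.randn(batch_size, 1, 1024)  # one CLIP image embedding per video (assumed width)

# [batch, frames, c, h, w] -> [batch * frames, c, h, w], as done just before conv_in
sample = sample.flatten(0, 1)

# Per-video conditioning is repeated once per frame so it lines up with the flattened sample
emb = emb.repeat_interleave(num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
assert sample.shape[0] == emb.shape[0] == encoder_hidden_states.shape[0]  # batch * frames

# After the UNet blocks, the output is folded back into (batch, frames, ...)
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
print(sample.shape)  # torch.Size([2, 16, 8, 72, 128])
```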
mimicmotion/pipelines/pipeline_mimicmotion.py
ADDED
@@ -0,0 +1,628 @@
1 |
+
import inspect
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Callable, Dict, List, Optional, Union
|
4 |
+
|
5 |
+
import PIL.Image
|
6 |
+
import einops
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
|
10 |
+
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
|
11 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
12 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
|
13 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \
|
14 |
+
import _resize_with_antialiasing, _append_dims
|
15 |
+
from diffusers.schedulers import EulerDiscreteScheduler
|
16 |
+
from diffusers.utils import BaseOutput, logging
|
17 |
+
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
18 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
19 |
+
|
20 |
+
from ..modules.pose_net import PoseNet
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
23 |
+
|
24 |
+
|
25 |
+
def _append_dims(x, target_dims):
|
26 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
27 |
+
dims_to_append = target_dims - x.ndim
|
28 |
+
if dims_to_append < 0:
|
29 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
30 |
+
return x[(...,) + (None,) * dims_to_append]
|
31 |
+
|
32 |
+
|
33 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
34 |
+
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
35 |
+
batch_size, channels, num_frames, height, width = video.shape
|
36 |
+
outputs = []
|
37 |
+
for batch_idx in range(batch_size):
|
38 |
+
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
39 |
+
batch_output = processor.postprocess(batch_vid, output_type)
|
40 |
+
|
41 |
+
outputs.append(batch_output)
|
42 |
+
|
43 |
+
if output_type == "np":
|
44 |
+
outputs = np.stack(outputs)
|
45 |
+
|
46 |
+
elif output_type == "pt":
|
47 |
+
outputs = torch.stack(outputs)
|
48 |
+
|
49 |
+
elif not output_type == "pil":
|
50 |
+
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
51 |
+
|
52 |
+
return outputs
|
53 |
+
|
54 |
+
|
55 |
+
@dataclass
|
56 |
+
class MimicMotionPipelineOutput(BaseOutput):
|
57 |
+
r"""
|
58 |
+
Output class for mimicmotion pipeline.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):
|
62 |
+
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
|
63 |
+
num_frames, height, width, num_channels)`.
|
64 |
+
"""
|
65 |
+
|
66 |
+
frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
|
67 |
+
|
68 |
+
|
69 |
+
class MimicMotionPipeline(DiffusionPipeline):
|
70 |
+
r"""
|
71 |
+
Pipeline to generate video from an input image using Stable Video Diffusion.
|
72 |
+
|
73 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
74 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
75 |
+
|
76 |
+
Args:
|
77 |
+
vae ([`AutoencoderKLTemporalDecoder`]):
|
78 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
79 |
+
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
|
80 |
+
Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K]
|
81 |
+
(https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
|
82 |
+
unet ([`UNetSpatioTemporalConditionModel`]):
|
83 |
+
A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
|
84 |
+
scheduler ([`EulerDiscreteScheduler`]):
|
85 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
86 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
87 |
+
A `CLIPImageProcessor` to extract features from generated images.
|
88 |
+
pose_net ([`PoseNet`]):
|
89 |
+
A `PoseNet` module to inject pose signals into the UNet.
|
90 |
+
"""
|
91 |
+
|
92 |
+
model_cpu_offload_seq = "image_encoder->unet->vae"
|
93 |
+
_callback_tensor_inputs = ["latents"]
|
94 |
+
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
vae: AutoencoderKLTemporalDecoder,
|
98 |
+
image_encoder: CLIPVisionModelWithProjection,
|
99 |
+
unet: UNetSpatioTemporalConditionModel,
|
100 |
+
scheduler: EulerDiscreteScheduler,
|
101 |
+
feature_extractor: CLIPImageProcessor,
|
102 |
+
pose_net: PoseNet,
|
103 |
+
):
|
104 |
+
super().__init__()
|
105 |
+
|
106 |
+
self.register_modules(
|
107 |
+
vae=vae,
|
108 |
+
image_encoder=image_encoder,
|
109 |
+
unet=unet,
|
110 |
+
scheduler=scheduler,
|
111 |
+
feature_extractor=feature_extractor,
|
112 |
+
pose_net=pose_net,
|
113 |
+
)
|
114 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
115 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
116 |
+
|
117 |
+
def _encode_image(
|
118 |
+
self,
|
119 |
+
image: PipelineImageInput,
|
120 |
+
device: Union[str, torch.device],
|
121 |
+
num_videos_per_prompt: int,
|
122 |
+
do_classifier_free_guidance: bool):
|
123 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
124 |
+
|
125 |
+
if not isinstance(image, torch.Tensor):
|
126 |
+
image = self.image_processor.pil_to_numpy(image)
|
127 |
+
image = self.image_processor.numpy_to_pt(image)
|
128 |
+
|
129 |
+
# We normalize the image before resizing to match with the original implementation.
|
130 |
+
# Then we unnormalize it after resizing.
|
131 |
+
image = image * 2.0 - 1.0
|
132 |
+
image = _resize_with_antialiasing(image, (224, 224))
|
133 |
+
image = (image + 1.0) / 2.0
|
134 |
+
|
135 |
+
# Normalize the image with for CLIP input
|
136 |
+
image = self.feature_extractor(
|
137 |
+
images=image,
|
138 |
+
do_normalize=True,
|
139 |
+
do_center_crop=False,
|
140 |
+
do_resize=False,
|
141 |
+
do_rescale=False,
|
142 |
+
return_tensors="pt",
|
143 |
+
).pixel_values
|
144 |
+
|
145 |
+
image = image.to(device=device, dtype=dtype)
|
146 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
147 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
148 |
+
|
149 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
150 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
151 |
+
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
152 |
+
image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
153 |
+
|
154 |
+
if do_classifier_free_guidance:
|
155 |
+
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
156 |
+
|
157 |
+
# For classifier free guidance, we need to do two forward passes.
|
158 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
159 |
+
# to avoid doing two forward passes
|
160 |
+
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
|
161 |
+
|
162 |
+
return image_embeddings
|
163 |
+
|
164 |
+
def _encode_vae_image(
|
165 |
+
self,
|
166 |
+
image: torch.Tensor,
|
167 |
+
device: Union[str, torch.device],
|
168 |
+
num_videos_per_prompt: int,
|
169 |
+
do_classifier_free_guidance: bool,
|
170 |
+
):
|
171 |
+
image = image.to(device=device, dtype=self.vae.dtype)
|
172 |
+
image_latents = self.vae.encode(image).latent_dist.mode()
|
173 |
+
|
174 |
+
if do_classifier_free_guidance:
|
175 |
+
negative_image_latents = torch.zeros_like(image_latents)
|
176 |
+
|
177 |
+
# For classifier free guidance, we need to do two forward passes.
|
178 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
179 |
+
# to avoid doing two forward passes
|
180 |
+
image_latents = torch.cat([negative_image_latents, image_latents])
|
181 |
+
|
182 |
+
# duplicate image_latents for each generation per prompt, using mps friendly method
|
183 |
+
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
|
184 |
+
|
185 |
+
return image_latents
|
186 |
+
|
187 |
+
def _get_add_time_ids(
|
188 |
+
self,
|
189 |
+
fps: int,
|
190 |
+
motion_bucket_id: int,
|
191 |
+
noise_aug_strength: float,
|
192 |
+
dtype: torch.dtype,
|
193 |
+
batch_size: int,
|
194 |
+
num_videos_per_prompt: int,
|
195 |
+
do_classifier_free_guidance: bool,
|
196 |
+
):
|
197 |
+
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
|
198 |
+
|
199 |
+
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
|
200 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
201 |
+
|
202 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
203 |
+
raise ValueError(
|
204 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
|
205 |
+
f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. " \
|
206 |
+
f"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
207 |
+
)
|
208 |
+
|
209 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
210 |
+
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
|
211 |
+
|
212 |
+
if do_classifier_free_guidance:
|
213 |
+
add_time_ids = torch.cat([add_time_ids, add_time_ids])
|
214 |
+
|
215 |
+
return add_time_ids
|
216 |
+
|
217 |
+
def decode_latents(
|
218 |
+
self,
|
219 |
+
latents: torch.Tensor,
|
220 |
+
num_frames: int,
|
221 |
+
decode_chunk_size: int = 8):
|
222 |
+
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
|
223 |
+
latents = latents.flatten(0, 1)
|
224 |
+
|
225 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
226 |
+
|
227 |
+
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
|
228 |
+
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
|
229 |
+
|
230 |
+
# decode decode_chunk_size frames at a time to avoid OOM
|
231 |
+
frames = []
|
232 |
+
for i in range(0, latents.shape[0], decode_chunk_size):
|
233 |
+
num_frames_in = latents[i: i + decode_chunk_size].shape[0]
|
234 |
+
decode_kwargs = {}
|
235 |
+
if accepts_num_frames:
|
236 |
+
# we only pass num_frames_in if it's expected
|
237 |
+
decode_kwargs["num_frames"] = num_frames_in
|
238 |
+
|
239 |
+
frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
|
240 |
+
frames.append(frame.cpu())
|
241 |
+
frames = torch.cat(frames, dim=0)
|
242 |
+
|
243 |
+
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
244 |
+
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
|
245 |
+
|
246 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
247 |
+
frames = frames.float()
|
248 |
+
return frames
|
249 |
+
|
250 |
+
def check_inputs(self, image, height, width):
|
251 |
+
if (
|
252 |
+
not isinstance(image, torch.Tensor)
|
253 |
+
and not isinstance(image, PIL.Image.Image)
|
254 |
+
and not isinstance(image, list)
|
255 |
+
):
|
256 |
+
raise ValueError(
|
257 |
+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
258 |
+
f" {type(image)}"
|
259 |
+
)
|
260 |
+
|
261 |
+
if height % 8 != 0 or width % 8 != 0:
|
262 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
263 |
+
|
264 |
+
def prepare_latents(
|
265 |
+
self,
|
266 |
+
batch_size: int,
|
267 |
+
num_frames: int,
|
268 |
+
num_channels_latents: int,
|
269 |
+
height: int,
|
270 |
+
width: int,
|
271 |
+
dtype: torch.dtype,
|
272 |
+
device: Union[str, torch.device],
|
273 |
+
generator: torch.Generator,
|
274 |
+
latents: Optional[torch.Tensor] = None,
|
275 |
+
):
|
276 |
+
shape = (
|
277 |
+
batch_size,
|
278 |
+
num_frames,
|
279 |
+
num_channels_latents // 2,
|
280 |
+
height // self.vae_scale_factor,
|
281 |
+
width // self.vae_scale_factor,
|
282 |
+
)
|
283 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
284 |
+
raise ValueError(
|
285 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
286 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
287 |
+
)
|
288 |
+
|
289 |
+
if latents is None:
|
290 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
291 |
+
else:
|
292 |
+
latents = latents.to(device)
|
293 |
+
|
294 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
295 |
+
latents = latents * self.scheduler.init_noise_sigma
|
296 |
+
return latents
|
297 |
+
|
298 |
+
@property
|
299 |
+
def guidance_scale(self):
|
300 |
+
return self._guidance_scale
|
301 |
+
|
302 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
303 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
304 |
+
# corresponds to doing no classifier free guidance.
|
305 |
+
@property
|
306 |
+
def do_classifier_free_guidance(self):
|
307 |
+
if isinstance(self.guidance_scale, (int, float)):
|
308 |
+
return self.guidance_scale > 1
|
309 |
+
return self.guidance_scale.max() > 1
|
310 |
+
|
311 |
+
@property
|
312 |
+
def num_timesteps(self):
|
313 |
+
return self._num_timesteps
|
314 |
+
|
315 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
316 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
317 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
318 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
319 |
+
# and should be between [0, 1]
|
320 |
+
|
321 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
322 |
+
extra_step_kwargs = {}
|
323 |
+
if accepts_eta:
|
324 |
+
extra_step_kwargs["eta"] = eta
|
325 |
+
|
326 |
+
# check if the scheduler accepts generator
|
327 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
328 |
+
if accepts_generator:
|
329 |
+
extra_step_kwargs["generator"] = generator
|
330 |
+
return extra_step_kwargs
|
331 |
+
|
332 |
+
@torch.no_grad()
|
333 |
+
def __call__(
|
334 |
+
self,
|
335 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
336 |
+
image_pose: torch.FloatTensor,
|
337 |
+
height: int = 576,
|
338 |
+
width: int = 1024,
|
339 |
+
num_frames: Optional[int] = None,
|
340 |
+
tile_size: Optional[int] = 16,
|
341 |
+
tile_overlap: Optional[int] = 4,
|
342 |
+
num_inference_steps: int = 25,
|
343 |
+
min_guidance_scale: float = 1.0,
|
344 |
+
max_guidance_scale: float = 3.0,
|
345 |
+
fps: int = 7,
|
346 |
+
motion_bucket_id: int = 127,
|
347 |
+
noise_aug_strength: float = 0.02,
|
348 |
+
image_only_indicator: bool = False,
|
349 |
+
decode_chunk_size: Optional[int] = None,
|
350 |
+
num_videos_per_prompt: Optional[int] = 1,
|
351 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
352 |
+
latents: Optional[torch.FloatTensor] = None,
|
353 |
+
output_type: Optional[str] = "pil",
|
354 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
355 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
356 |
+
return_dict: bool = True,
|
357 |
+
device: Union[str, torch.device] = None,
|
358 |
+
):
|
359 |
+
r"""
|
360 |
+
The call function to the pipeline for generation.
|
361 |
+
|
362 |
+
Args:
|
363 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
364 |
+
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
365 |
+
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/
|
366 |
+
feature_extractor/preprocessor_config.json).
|
367 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
368 |
+
The height in pixels of the generated image.
|
369 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
370 |
+
The width in pixels of the generated image.
|
371 |
+
num_frames (`int`, *optional*):
|
372 |
+
The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid`
|
373 |
+
and to 25 for `stable-video-diffusion-img2vid-xt`
|
374 |
+
num_inference_steps (`int`, *optional*, defaults to 25):
|
375 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
376 |
+
expense of slower inference.
|
377 |
+
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
378 |
+
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
379 |
+
max_guidance_scale (`float`, *optional*, defaults to 3.0):
|
380 |
+
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
381 |
+
fps (`int`, *optional*, defaults to 7):
|
382 |
+
Frames per second. The rate at which the generated frames are exported to a video after generation.
|
383 |
+
Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
|
384 |
+
motion_bucket_id (`int`, *optional*, defaults to 127):
|
385 |
+
The motion bucket ID. Used as conditioning for the generation.
|
386 |
+
The higher the number the more motion will be in the video.
|
387 |
+
noise_aug_strength (`float`, *optional*, defaults to 0.02):
|
388 |
+
The amount of noise added to the init image,
|
389 |
+
the higher it is the less the video will look like the init image. Increase it for more motion.
|
390 |
+
image_only_indicator (`bool`, *optional*, defaults to False):
|
391 |
+
Whether to treat the inputs as batch of images instead of videos.
|
392 |
+
decode_chunk_size (`int`, *optional*):
|
393 |
+
The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
|
394 |
+
between frames, but also the higher the memory consumption.
|
395 |
+
By default, the decoder will decode all frames at once for maximal quality.
|
396 |
+
Reduce `decode_chunk_size` to reduce memory usage.
|
397 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
398 |
+
The number of images to generate per prompt.
|
399 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
400 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
401 |
+
generation deterministic.
|
402 |
+
latents (`torch.FloatTensor`, *optional*):
|
403 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
404 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
405 |
+
tensor is generated by sampling using the supplied random `generator`.
|
406 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
407 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
408 |
+
callback_on_step_end (`Callable`, *optional*):
|
409 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
410 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
411 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
412 |
+
`callback_on_step_end_tensor_inputs`.
|
413 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
414 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
415 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
416 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
417 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
418 |
+
Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
419 |
+
plain tuple.
|
420 |
+
device (`str` or `torch.device`, *optional*):
|
421 |
+
The device on which the pipeline runs. Defaults to the pipeline's execution device.
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
|
425 |
+
If `return_dict` is `True`,
|
426 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
|
427 |
+
otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
|
428 |
+
|
429 |
+
Examples:
|
430 |
+
|
431 |
+
```py
|
432 |
+
from diffusers import StableVideoDiffusionPipeline
|
433 |
+
from diffusers.utils import load_image, export_to_video
|
434 |
+
|
435 |
+
pipe = StableVideoDiffusionPipeline.from_pretrained(
|
436 |
+
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
|
437 |
+
pipe.to("cuda")
|
438 |
+
|
439 |
+
image = load_image(
|
440 |
+
"https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
|
441 |
+
image = image.resize((1024, 576))
|
442 |
+
|
443 |
+
frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
|
444 |
+
export_to_video(frames, "generated.mp4", fps=7)
|
445 |
+
```
|
446 |
+
"""
|
447 |
+
# 0. Default height and width to unet
|
448 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
449 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
450 |
+
|
451 |
+
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
|
452 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
|
453 |
+
|
454 |
+
# 1. Check inputs. Raise error if not correct
|
455 |
+
self.check_inputs(image, height, width)
|
456 |
+
|
457 |
+
# 2. Define call parameters
|
458 |
+
if isinstance(image, PIL.Image.Image):
|
459 |
+
batch_size = 1
|
460 |
+
elif isinstance(image, list):
|
461 |
+
batch_size = len(image)
|
462 |
+
else:
|
463 |
+
batch_size = image.shape[0]
|
464 |
+
device = device if device is not None else self._execution_device
|
465 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
466 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
467 |
+
# corresponds to doing no classifier free guidance.
|
468 |
+
self._guidance_scale = max_guidance_scale
|
469 |
+
|
470 |
+
# 3. Encode input image
|
471 |
+
self.image_encoder.to(device)
|
472 |
+
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
|
473 |
+
self.image_encoder.cpu()
|
474 |
+
|
475 |
+
# NOTE: Stable Diffusion Video was conditioned on fps - 1, which
|
476 |
+
# is why it is reduced here.
|
477 |
+
fps = fps - 1
|
478 |
+
|
479 |
+
# 4. Encode input image using VAE
|
480 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
|
481 |
+
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
|
482 |
+
image = image + noise_aug_strength * noise
|
483 |
+
|
484 |
+
self.vae.to(device)
|
485 |
+
image_latents = self._encode_vae_image(
|
486 |
+
image,
|
487 |
+
device=device,
|
488 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
489 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
490 |
+
)
|
491 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
492 |
+
self.vae.cpu()
|
493 |
+
|
494 |
+
# Repeat the image latents for each frame so we can concatenate them with the noise
|
495 |
+
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
|
496 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
497 |
+
|
498 |
+
# 5. Get Added Time IDs
|
499 |
+
added_time_ids = self._get_add_time_ids(
|
500 |
+
fps,
|
501 |
+
motion_bucket_id,
|
502 |
+
noise_aug_strength,
|
503 |
+
image_embeddings.dtype,
|
504 |
+
batch_size,
|
505 |
+
num_videos_per_prompt,
|
506 |
+
self.do_classifier_free_guidance,
|
507 |
+
)
|
508 |
+
added_time_ids = added_time_ids.to(device)
|
509 |
+
|
510 |
+
# 6. Prepare timesteps
|
511 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)
|
512 |
+
|
513 |
+
# 7. Prepare latent variables
|
514 |
+
num_channels_latents = self.unet.config.in_channels
|
515 |
+
latents = self.prepare_latents(
|
516 |
+
batch_size * num_videos_per_prompt,
|
517 |
+
tile_size,
|
518 |
+
num_channels_latents,
|
519 |
+
height,
|
520 |
+
width,
|
521 |
+
image_embeddings.dtype,
|
522 |
+
device,
|
523 |
+
generator,
|
524 |
+
latents,
|
525 |
+
)
|
526 |
+
latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames]
|
527 |
+
|
528 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
529 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
|
530 |
+
|
531 |
+
# 9. Prepare guidance scale
|
532 |
+
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
|
533 |
+
guidance_scale = guidance_scale.to(device, latents.dtype)
|
534 |
+
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
535 |
+
guidance_scale = _append_dims(guidance_scale, latents.ndim)
|
536 |
+
|
537 |
+
self._guidance_scale = guidance_scale
|
538 |
+
|
539 |
+
# 10. Denoising loop
|
540 |
+
self._num_timesteps = len(timesteps)
|
541 |
+
indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in
|
542 |
+
range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
|
543 |
+
if indices[-1][-1] < num_frames - 1:
|
544 |
+
indices.append([0, *range(num_frames - tile_size + 1, num_frames)])
|
545 |
+
|
546 |
+
self.pose_net.to(device)
|
547 |
+
self.unet.to(device)
|
548 |
+
|
549 |
+
with torch.cuda.device(device):
|
550 |
+
torch.cuda.empty_cache()
|
551 |
+
|
552 |
+
with self.progress_bar(total=len(timesteps) * len(indices)) as progress_bar:
|
553 |
+
for i, t in enumerate(timesteps):
|
554 |
+
# expand the latents if we are doing classifier free guidance
|
555 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
556 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
557 |
+
|
558 |
+
# Concatenate image_latents over channels dimension
|
559 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
560 |
+
|
561 |
+
# predict the noise residual
|
562 |
+
noise_pred = torch.zeros_like(image_latents)
|
563 |
+
noise_pred_cnt = image_latents.new_zeros((num_frames,))
|
564 |
+
weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
|
565 |
+
weight = torch.minimum(weight, 2 - weight)
|
566 |
+
for idx in indices:
|
567 |
+
|
568 |
+
# classification-free inference
|
569 |
+
pose_latents = self.pose_net(image_pose[idx].to(device))
|
570 |
+
_noise_pred = self.unet(
|
571 |
+
latent_model_input[:1, idx],
|
572 |
+
t,
|
573 |
+
encoder_hidden_states=image_embeddings[:1],
|
574 |
+
added_time_ids=added_time_ids[:1],
|
575 |
+
pose_latents=None,
|
576 |
+
image_only_indicator=image_only_indicator,
|
577 |
+
return_dict=False,
|
578 |
+
)[0]
|
579 |
+
noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None]
|
580 |
+
|
581 |
+
# normal inference
|
582 |
+
_noise_pred = self.unet(
|
583 |
+
latent_model_input[1:, idx],
|
584 |
+
t,
|
585 |
+
encoder_hidden_states=image_embeddings[1:],
|
586 |
+
added_time_ids=added_time_ids[1:],
|
587 |
+
pose_latents=pose_latents,
|
588 |
+
image_only_indicator=image_only_indicator,
|
589 |
+
return_dict=False,
|
590 |
+
)[0]
|
591 |
+
noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None]
|
592 |
+
|
593 |
+
noise_pred_cnt[idx] += weight
|
594 |
+
progress_bar.update()
|
595 |
+
noise_pred.div_(noise_pred_cnt[:, None, None, None])
|
596 |
+
|
597 |
+
# perform guidance
|
598 |
+
if self.do_classifier_free_guidance:
|
599 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
600 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
601 |
+
|
602 |
+
# compute the previous noisy sample x_t -> x_t-1
|
603 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
604 |
+
|
605 |
+
if callback_on_step_end is not None:
|
606 |
+
callback_kwargs = {}
|
607 |
+
for k in callback_on_step_end_tensor_inputs:
|
608 |
+
callback_kwargs[k] = locals()[k]
|
609 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
610 |
+
|
611 |
+
latents = callback_outputs.pop("latents", latents)
|
612 |
+
|
613 |
+
self.pose_net.cpu()
|
614 |
+
self.unet.cpu()
|
615 |
+
|
616 |
+
if not output_type == "latent":
|
617 |
+
self.vae.decoder.to(device)
|
618 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
619 |
+
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
620 |
+
else:
|
621 |
+
frames = latents
|
622 |
+
|
623 |
+
self.maybe_free_model_hooks()
|
624 |
+
|
625 |
+
if not return_dict:
|
626 |
+
return frames
|
627 |
+
|
628 |
+
return MimicMotionPipelineOutput(frames=frames)
|
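The tiled denoising loop in `__call__` is the least obvious part of the pipeline: every tile re-attaches frame 0 (the reference frame), and overlapping tiles are blended with triangular weights. The standalone sketch below reruns just that index and weight arithmetic with illustrative values (`num_frames=72`, `tile_size=16`, `tile_overlap=6`) so the pattern can be inspected without loading any model.

```python
import torch

num_frames, tile_size, tile_overlap = 72, 16, 6  # illustrative values only

# Same index construction as in __call__: sliding windows with overlap, each prefixed with frame 0
indices = [[0, *range(i + 1, min(i + tile_size, num_frames))]
           for i in range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
if indices[-1][-1] < num_frames - 1:
    indices.append([0, *range(num_frames - tile_size + 1, num_frames)])
print(indices[0][:5], indices[1][:5])  # [0, 1, 2, 3, 4] [0, 11, 12, 13, 14]

# Triangular blending weights: frames near a tile's centre count more than frames near its edges
weight = (torch.arange(tile_size) + 0.5) * 2.0 / tile_size
weight = torch.minimum(weight, 2 - weight)

# Accumulate per-frame weights the same way noise_pred_cnt does in the denoising loop
cnt = torch.zeros(num_frames)
for idx in indices:
    cnt[torch.tensor(idx)] += weight[: len(idx)]
print(bool(cnt.min() > 0))  # True: every frame is covered by at least one tile
```

Dividing the accumulated noise prediction by this count is what normalises the blend across overlapping tiles.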
mimicmotion/utils/__init__.py
ADDED
File without changes
|
mimicmotion/utils/geglu_patch.py
ADDED
@@ -0,0 +1,9 @@
1 |
+
import diffusers.models.activations
|
2 |
+
|
3 |
+
|
4 |
+
def patch_geglu_inplace():
|
5 |
+
"""Patch GEGLU with inplace multiplication to save GPU memory."""
|
6 |
+
def forward(self, hidden_states):
|
7 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
8 |
+
return hidden_states.mul_(self.gelu(gate))
|
9 |
+
diffusers.models.activations.GEGLU.forward = forward
|
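A minimal usage sketch: because the patch rebinds `GEGLU.forward` at class level, calling it once at start-up is enough for every GEGLU layer (existing or later-constructed) to use the in-place multiplication; the point is simply to save GPU memory during inference.

```python
# Apply the memory-saving GEGLU patch once, before running inference.
from mimicmotion.utils.geglu_patch import patch_geglu_inplace

patch_geglu_inplace()
```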
mimicmotion/utils/loader.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.utils.checkpoint
|
5 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
6 |
+
from diffusers.schedulers import EulerDiscreteScheduler
|
7 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
8 |
+
|
9 |
+
from ..modules.unet import UNetSpatioTemporalConditionModel
|
10 |
+
from ..modules.pose_net import PoseNet
|
11 |
+
from ..pipelines.pipeline_mimicmotion import MimicMotionPipeline
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
class MimicMotionModel(torch.nn.Module):
|
16 |
+
def __init__(self, base_model_path):
|
17 |
+
"""construnct base model components and load pretrained svd model except pose-net
|
18 |
+
Args:
|
19 |
+
base_model_path (str): pretrained svd model path
|
20 |
+
"""
|
21 |
+
super().__init__()
|
22 |
+
self.unet = UNetSpatioTemporalConditionModel.from_config(
|
23 |
+
UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet"))
|
24 |
+
self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
|
25 |
+
base_model_path, subfolder="vae", torch_dtype=torch.float16, variant="fp16")
|
26 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
27 |
+
base_model_path, subfolder="image_encoder", torch_dtype=torch.float16, variant="fp16")
|
28 |
+
self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
|
29 |
+
base_model_path, subfolder="scheduler")
|
30 |
+
self.feature_extractor = CLIPImageProcessor.from_pretrained(
|
31 |
+
base_model_path, subfolder="feature_extractor")
|
32 |
+
# pose_net
|
33 |
+
self.pose_net = PoseNet(noise_latent_channels=self.unet.config.block_out_channels[0])
|
34 |
+
|
35 |
+
def create_pipeline(infer_config, device):
|
36 |
+
"""create mimicmotion pipeline and load pretrained weight
|
37 |
+
|
38 |
+
Args:
|
39 |
+
infer_config (omegaconf.DictConfig): inference config providing `base_model_path` and `ckpt_path`
|
40 |
+
device (str or torch.device): "cpu" or "cuda:{device_id}"
|
41 |
+
"""
|
42 |
+
mimicmotion_models = MimicMotionModel(infer_config.base_model_path)
|
43 |
+
mimicmotion_models.load_state_dict(torch.load(infer_config.ckpt_path, map_location="cpu"), strict=False)
|
44 |
+
pipeline = MimicMotionPipeline(
|
45 |
+
vae=mimicmotion_models.vae,
|
46 |
+
image_encoder=mimicmotion_models.image_encoder,
|
47 |
+
unet=mimicmotion_models.unet,
|
48 |
+
scheduler=mimicmotion_models.noise_scheduler,
|
49 |
+
feature_extractor=mimicmotion_models.feature_extractor,
|
50 |
+
pose_net=mimicmotion_models.pose_net
|
51 |
+
)
|
52 |
+
return pipeline
|
53 |
+
|
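For reference, a sketch of how `create_pipeline` is wired up from a config; the paths mirror the ones used in `predict.py` below and are placeholders for wherever the SVD base model and the MimicMotion checkpoint actually live.

```python
import torch
from omegaconf import OmegaConf

from mimicmotion.utils.loader import create_pipeline

infer_config = OmegaConf.create({
    "base_model_path": "models/SVD/stable-video-diffusion-img2vid-xt-1-1",  # placeholder path
    "ckpt_path": "models/MimicMotion_1-1.pth",                              # placeholder path
})
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = create_pipeline(infer_config, device)
```

Note that `create_pipeline` only builds the modules and loads the state dict; moving submodules onto `device` happens lazily inside the pipeline's `__call__`.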
mimicmotion/utils/utils.py
ADDED
@@ -0,0 +1,12 @@
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from torchvision.io import write_video
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
def save_to_mp4(frames, save_path, fps=7):
|
9 |
+
frames = frames.permute((0, 2, 3, 1)) # (f, c, h, w) to (f, h, w, c)
|
10 |
+
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
11 |
+
write_video(save_path, frames, fps=fps)
|
12 |
+
|
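`save_to_mp4` expects a uint8 tensor shaped `(frames, channels, height, width)` and handles the permute to `(f, h, w, c)` that `torchvision.io.write_video` wants. A dummy call, just to show the convention (random frames, hypothetical output path):

```python
import torch

from mimicmotion.utils.utils import save_to_mp4

frames = torch.randint(0, 256, (15, 3, 576, 1024), dtype=torch.uint8)  # 15 random RGB frames
save_to_mp4(frames, "outputs/dummy.mp4", fps=15)  # parent directory is created automatically
```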
predict.py
ADDED
@@ -0,0 +1,363 @@
1 |
+
# predict.py
|
2 |
+
import subprocess
|
3 |
+
import time
|
4 |
+
from cog import BasePredictor, Input, Path
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
from omegaconf import OmegaConf
|
10 |
+
from datetime import datetime
|
11 |
+
|
12 |
+
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
|
13 |
+
from constants import ASPECT_RATIO
|
14 |
+
|
15 |
+
MODEL_CACHE = "models"
|
16 |
+
os.environ["HF_DATASETS_OFFLINE"] = "1"
|
17 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
18 |
+
os.environ["HF_HOME"] = MODEL_CACHE
|
19 |
+
os.environ["TORCH_HOME"] = MODEL_CACHE
|
20 |
+
os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE
|
21 |
+
os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE
|
22 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE
|
23 |
+
|
24 |
+
BASE_URL = f"https://weights.replicate.delivery/default/MimicMotion/{MODEL_CACHE}/"
|
25 |
+
|
26 |
+
|
27 |
+
def download_weights(url: str, dest: str) -> None:
|
28 |
+
# NOTE: when the download is a .tar archive, pget extracts it, so pass the parent folder as the destination
|
29 |
+
start = time.time()
|
30 |
+
print("[!] Initiating download from URL: ", url)
|
31 |
+
print("[~] Destination path: ", dest)
|
32 |
+
if ".tar" in dest:
|
33 |
+
dest = os.path.dirname(dest)
|
34 |
+
command = ["pget", "-vf" + ("x" if ".tar" in url else ""), url, dest]
|
35 |
+
try:
|
36 |
+
print(f"[~] Running command: {' '.join(command)}")
|
37 |
+
subprocess.check_call(command, close_fds=False)
|
38 |
+
except subprocess.CalledProcessError as e:
|
39 |
+
print(
|
40 |
+
f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}."
|
41 |
+
)
|
42 |
+
raise
|
43 |
+
print("[+] Download completed in: ", time.time() - start, "seconds")
|
44 |
+
|
45 |
+
|
46 |
+
class Predictor(BasePredictor):
|
47 |
+
def setup(self):
|
48 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
49 |
+
|
50 |
+
if not os.path.exists(MODEL_CACHE):
|
51 |
+
os.makedirs(MODEL_CACHE)
|
52 |
+
model_files = [
|
53 |
+
"DWPose.tar",
|
54 |
+
"MimicMotion.pth",
|
55 |
+
"MimicMotion_1-1.pth",
|
56 |
+
"SVD.tar",
|
57 |
+
]
|
58 |
+
for model_file in model_files:
|
59 |
+
url = BASE_URL + model_file
|
60 |
+
filename = url.split("/")[-1]
|
61 |
+
dest_path = os.path.join(MODEL_CACHE, filename)
|
62 |
+
if not os.path.exists(dest_path.replace(".tar", "")):
|
63 |
+
download_weights(url, dest_path)
|
64 |
+
|
65 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
66 |
+
print(f"Using device: {self.device}")
|
67 |
+
|
68 |
+
# Move imports here and make them global
|
69 |
+
# This ensures model files are downloaded before importing mimicmotion modules
|
70 |
+
global MimicMotionPipeline, create_pipeline, save_to_mp4, get_video_pose, get_image_pose
|
71 |
+
from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
|
72 |
+
from mimicmotion.utils.loader import create_pipeline
|
73 |
+
from mimicmotion.utils.utils import save_to_mp4
|
74 |
+
from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose
|
75 |
+
|
76 |
+
# Load config with new checkpoint as default
|
77 |
+
self.config = OmegaConf.create(
|
78 |
+
{
|
79 |
+
"base_model_path": "models/SVD/stable-video-diffusion-img2vid-xt-1-1",
|
80 |
+
"ckpt_path": "models/MimicMotion_1-1.pth",
|
81 |
+
}
|
82 |
+
)
|
83 |
+
|
84 |
+
# Create the pipeline with the new checkpoint
|
85 |
+
self.pipeline = create_pipeline(self.config, self.device)
|
86 |
+
self.current_checkpoint = "v1-1"
|
87 |
+
self.current_dtype = torch.get_default_dtype()
|
88 |
+
|
89 |
+
def predict(
|
90 |
+
self,
|
91 |
+
motion_video: Path = Input(
|
92 |
+
description="Reference video file containing the motion to be mimicked"
|
93 |
+
),
|
94 |
+
appearance_image: Path = Input(
|
95 |
+
description="Reference image file for the appearance of the generated video"
|
96 |
+
),
|
97 |
+
resolution: int = Input(
|
98 |
+
description="Height of the output video in pixels. Width is automatically calculated.",
|
99 |
+
default=576,
|
100 |
+
ge=64,
|
101 |
+
le=1024,
|
102 |
+
),
|
103 |
+
chunk_size: int = Input(
|
104 |
+
description="Number of frames to generate in each processing chunk",
|
105 |
+
default=16,
|
106 |
+
ge=2,
|
107 |
+
),
|
108 |
+
frames_overlap: int = Input(
|
109 |
+
description="Number of overlapping frames between chunks for smoother transitions",
|
110 |
+
default=6,
|
111 |
+
ge=0,
|
112 |
+
),
|
113 |
+
denoising_steps: int = Input(
|
114 |
+
description="Number of denoising steps in the diffusion process. More steps can improve quality but increase processing time.",
|
115 |
+
default=25,
|
116 |
+
ge=1,
|
117 |
+
le=100,
|
118 |
+
),
|
119 |
+
noise_strength: float = Input(
|
120 |
+
description="Strength of noise augmentation. Higher values add more variation but may reduce coherence with the reference.",
|
121 |
+
default=0.0,
|
122 |
+
ge=0.0,
|
123 |
+
le=1.0,
|
124 |
+
),
|
125 |
+
guidance_scale: float = Input(
|
126 |
+
description="Strength of guidance towards the reference. Higher values adhere more closely to the reference but may reduce creativity.",
|
127 |
+
default=2.0,
|
128 |
+
ge=0.1,
|
129 |
+
le=10.0,
|
130 |
+
),
|
131 |
+
sample_stride: int = Input(
|
132 |
+
description="Interval for sampling frames from the reference video. Higher values skip more frames.",
|
133 |
+
default=2,
|
134 |
+
ge=1,
|
135 |
+
),
|
136 |
+
output_frames_per_second: int = Input(
|
137 |
+
description="Frames per second of the output video. Affects playback speed.",
|
138 |
+
default=15,
|
139 |
+
ge=1,
|
140 |
+
le=60,
|
141 |
+
),
|
142 |
+
seed: int = Input(
|
143 |
+
description="Random seed. Leave blank to randomize the seed",
|
144 |
+
default=None,
|
145 |
+
),
|
146 |
+
checkpoint_version: str = Input(
|
147 |
+
description="Choose the checkpoint version to use",
|
148 |
+
choices=["v1", "v1-1"],
|
149 |
+
default="v1-1",
|
150 |
+
),
|
151 |
+
) -> Path:
|
152 |
+
"""Run a single prediction on the model"""
|
153 |
+
|
154 |
+
ref_video = motion_video
|
155 |
+
ref_image = appearance_image
|
156 |
+
num_frames = chunk_size
|
157 |
+
num_inference_steps = denoising_steps
|
158 |
+
noise_aug_strength = noise_strength
|
159 |
+
fps = output_frames_per_second
|
160 |
+
use_fp16 = True
|
161 |
+
|
162 |
+
if seed is None:
|
163 |
+
seed = int.from_bytes(os.urandom(2), "big")
|
164 |
+
print(f"Using seed: {seed}")
|
165 |
+
|
166 |
+
need_pipeline_update = False
|
167 |
+
|
168 |
+
# Check if we need to switch checkpoints
|
169 |
+
if checkpoint_version != self.current_checkpoint:
|
170 |
+
if checkpoint_version == "v1":
|
171 |
+
self.config.ckpt_path = "models/MimicMotion.pth"
|
172 |
+
else: # v1-1
|
173 |
+
self.config.ckpt_path = "models/MimicMotion_1-1.pth"
|
174 |
+
need_pipeline_update = True
|
175 |
+
self.current_checkpoint = checkpoint_version
|
176 |
+
|
177 |
+
# Check if we need to switch dtype
|
178 |
+
target_dtype = torch.float16 if use_fp16 else torch.float32
|
179 |
+
if target_dtype != self.current_dtype:
|
180 |
+
torch.set_default_dtype(target_dtype)
|
181 |
+
need_pipeline_update = True
|
182 |
+
self.current_dtype = target_dtype
|
183 |
+
|
184 |
+
# Update pipeline if needed
|
185 |
+
if need_pipeline_update:
|
186 |
+
print(
|
187 |
+
f"Updating pipeline with checkpoint: {self.config.ckpt_path} and dtype: {torch.get_default_dtype()}"
|
188 |
+
)
|
189 |
+
self.pipeline = create_pipeline(self.config, self.device)
|
190 |
+
|
191 |
+
print(f"Using checkpoint: {self.config.ckpt_path}")
|
192 |
+
print(f"Using dtype: {torch.get_default_dtype()}")
|
193 |
+
|
194 |
+
print(
|
195 |
+
f"[!] ({type(ref_video)}) ref_video={ref_video}, "
|
196 |
+
f"[!] ({type(ref_image)}) ref_image={ref_image}, "
|
197 |
+
f"[!] ({type(resolution)}) resolution={resolution}, "
|
198 |
+
f"[!] ({type(num_frames)}) num_frames={num_frames}, "
|
199 |
+
f"[!] ({type(frames_overlap)}) frames_overlap={frames_overlap}, "
|
200 |
+
f"[!] ({type(num_inference_steps)}) num_inference_steps={num_inference_steps}, "
|
201 |
+
f"[!] ({type(noise_aug_strength)}) noise_aug_strength={noise_aug_strength}, "
|
202 |
+
f"[!] ({type(guidance_scale)}) guidance_scale={guidance_scale}, "
|
203 |
+
f"[!] ({type(sample_stride)}) sample_stride={sample_stride}, "
|
204 |
+
f"[!] ({type(fps)}) fps={fps}, "
|
205 |
+
f"[!] ({type(seed)}) seed={seed}, "
|
206 |
+
f"[!] ({type(use_fp16)}) use_fp16={use_fp16}"
|
207 |
+
)
|
208 |
+
|
209 |
+
# Input validation
|
210 |
+
if not ref_video.exists():
|
211 |
+
raise ValueError(f"Reference video file does not exist: {ref_video}")
|
212 |
+
if not ref_image.exists():
|
213 |
+
raise ValueError(f"Reference image file does not exist: {ref_image}")
|
214 |
+
|
215 |
+
if resolution % 8 != 0:
|
216 |
+
raise ValueError(f"Resolution must be a multiple of 8, got {resolution}")
|
217 |
+
|
218 |
+
if resolution < 64 or resolution > 1024:
|
219 |
+
raise ValueError(
|
220 |
+
f"Resolution must be between 64 and 1024, got {resolution}"
|
221 |
+
)
|
222 |
+
|
223 |
+
if num_frames <= frames_overlap:
|
224 |
+
raise ValueError(
|
225 |
+
f"Number of frames ({num_frames}) must be greater than frames overlap ({frames_overlap})"
|
226 |
+
)
|
227 |
+
|
228 |
+
if num_frames < 2:
|
229 |
+
raise ValueError(f"Number of frames must be at least 2, got {num_frames}")
|
230 |
+
|
231 |
+
if frames_overlap < 0:
|
232 |
+
raise ValueError(
|
233 |
+
f"Frames overlap must be non-negative, got {frames_overlap}"
|
234 |
+
)
|
235 |
+
|
236 |
+
if num_inference_steps < 1 or num_inference_steps > 100:
|
237 |
+
raise ValueError(
|
238 |
+
f"Number of inference steps must be between 1 and 100, got {num_inference_steps}"
|
239 |
+
)
|
240 |
+
|
241 |
+
if noise_aug_strength < 0.0 or noise_aug_strength > 1.0:
|
242 |
+
raise ValueError(
|
243 |
+
f"Noise augmentation strength must be between 0.0 and 1.0, got {noise_aug_strength}"
|
244 |
+
)
|
245 |
+
|
246 |
+
if guidance_scale < 0.1 or guidance_scale > 10.0:
|
247 |
+
raise ValueError(
|
248 |
+
f"Guidance scale must be between 0.1 and 10.0, got {guidance_scale}"
|
249 |
+
)
|
250 |
+
|
251 |
+
if sample_stride < 1:
|
252 |
+
raise ValueError(f"Sample stride must be at least 1, got {sample_stride}")
|
253 |
+
|
254 |
+
if fps < 1 or fps > 60:
|
255 |
+
raise ValueError(f"FPS must be between 1 and 60, got {fps}")
|
256 |
+
|
257 |
+
try:
|
258 |
+
# Preprocess
|
259 |
+
pose_pixels, image_pixels = self.preprocess(
|
260 |
+
str(ref_video),
|
261 |
+
str(ref_image),
|
262 |
+
resolution=resolution,
|
263 |
+
sample_stride=sample_stride,
|
264 |
+
)
|
265 |
+
|
266 |
+
# Run pipeline
|
267 |
+
video_frames = self.run_pipeline(
|
268 |
+
image_pixels,
|
269 |
+
pose_pixels,
|
270 |
+
num_frames=num_frames,
|
271 |
+
frames_overlap=frames_overlap,
|
272 |
+
num_inference_steps=num_inference_steps,
|
273 |
+
noise_aug_strength=noise_aug_strength,
|
274 |
+
guidance_scale=guidance_scale,
|
275 |
+
seed=seed,
|
276 |
+
)
|
277 |
+
|
278 |
+
# Save output
|
279 |
+
output_path = f"/tmp/output_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"
|
280 |
+
save_to_mp4(video_frames, output_path, fps=fps)
|
281 |
+
|
282 |
+
return Path(output_path)
|
283 |
+
|
284 |
+
except Exception as e:
|
285 |
+
print(f"An error occurred during prediction: {str(e)}")
|
286 |
+
raise
|
287 |
+
|
288 |
+
def preprocess(self, video_path, image_path, resolution=576, sample_stride=2):
|
289 |
+
image_pixels = Image.open(image_path).convert("RGB")
|
290 |
+
image_pixels = pil_to_tensor(image_pixels) # (c, h, w)
|
291 |
+
h, w = image_pixels.shape[-2:]
|
292 |
+
|
293 |
+
if h > w:
|
294 |
+
w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
|
295 |
+
else:
|
296 |
+
w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
|
297 |
+
|
298 |
+
h_w_ratio = float(h) / float(w)
|
299 |
+
if h_w_ratio < h_target / w_target:
|
300 |
+
h_resize, w_resize = h_target, int(h_target / h_w_ratio)
|
301 |
+
else:
|
302 |
+
h_resize, w_resize = int(w_target * h_w_ratio), w_target
|
303 |
+
|
304 |
+
image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
|
305 |
+
image_pixels = center_crop(image_pixels, [h_target, w_target])
|
306 |
+
image_pixels = image_pixels.permute((1, 2, 0)).numpy()
|
307 |
+
|
308 |
+
image_pose = get_image_pose(image_pixels)
|
309 |
+
video_pose = get_video_pose(
|
310 |
+
video_path, image_pixels, sample_stride=sample_stride
|
311 |
+
)
|
312 |
+
|
313 |
+
pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
|
314 |
+
image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
|
315 |
+
|
316 |
+
return (
|
317 |
+
torch.from_numpy(pose_pixels.copy()) / 127.5 - 1,
|
318 |
+
torch.from_numpy(image_pixels) / 127.5 - 1,
|
319 |
+
)
|
320 |
+
|
321 |
+
def run_pipeline(
|
322 |
+
self,
|
323 |
+
image_pixels,
|
324 |
+
pose_pixels,
|
325 |
+
num_frames,
|
326 |
+
frames_overlap,
|
327 |
+
num_inference_steps,
|
328 |
+
noise_aug_strength,
|
329 |
+
guidance_scale,
|
330 |
+
seed,
|
331 |
+
):
|
332 |
+
image_pixels = [
|
333 |
+
Image.fromarray(
|
334 |
+
(img.cpu().numpy().transpose(1, 2, 0) * 127.5 + 127.5).astype(np.uint8)
|
335 |
+
)
|
336 |
+
for img in image_pixels
|
337 |
+
]
|
338 |
+
pose_pixels = pose_pixels.unsqueeze(0).to(self.device)
|
339 |
+
|
340 |
+
generator = torch.Generator(device=self.device)
|
341 |
+
generator.manual_seed(seed)
|
342 |
+
|
343 |
+
frames = self.pipeline(
|
344 |
+
image_pixels,
|
345 |
+
image_pose=pose_pixels,
|
346 |
+
num_frames=pose_pixels.size(1),
|
347 |
+
tile_size=num_frames,
|
348 |
+
tile_overlap=frames_overlap,
|
349 |
+
height=pose_pixels.shape[-2],
|
350 |
+
width=pose_pixels.shape[-1],
|
351 |
+
fps=7,
|
352 |
+
noise_aug_strength=noise_aug_strength,
|
353 |
+
num_inference_steps=num_inference_steps,
|
354 |
+
generator=generator,
|
355 |
+
min_guidance_scale=guidance_scale,
|
356 |
+
max_guidance_scale=guidance_scale,
|
357 |
+
decode_chunk_size=8,
|
358 |
+
output_type="pt",
|
359 |
+
device=self.device,
|
360 |
+
).frames.cpu()
|
361 |
+
|
362 |
+
video_frames = (frames * 255.0).to(torch.uint8)
|
363 |
+
return video_frames[0, 1:] # Remove the first frame (reference image)
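A worked example of the resize/crop arithmetic in `Predictor.preprocess`, assuming `ASPECT_RATIO = 9 / 16` (an assumption for illustration; the real value is defined in `constants.py`):

```python
ASPECT_RATIO = 9 / 16   # assumed value; see constants.py
resolution = 576

# A portrait reference image (h > w) keeps `resolution` as the output width and
# derives the height, rounded down to a multiple of 64.
h, w = 1000, 800
w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
print(w_target, h_target)          # 576 1024

# Resize so the image covers the target box, then center-crop to (h_target, w_target).
h_w_ratio = h / w                  # 1.25, "flatter" than the 1024/576 target ratio
if h_w_ratio < h_target / w_target:
    h_resize, w_resize = h_target, int(h_target / h_w_ratio)
else:
    h_resize, w_resize = int(w_target * h_w_ratio), w_target
print(h_resize, w_resize)          # 1024 819 -> center-cropped to 1024 x 576
```

The pose sequence is rendered at the same target size, which is why `run_pipeline` can read `height` and `width` straight off `pose_pixels.shape`.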
|