diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..92fb47455d7260c8eac7e7391b6a481f672f0754 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+visuals/latte.gif filter=lfs diff=lfs merge=lfs -text
+visuals/latteT2V.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e2bf27960d069b5280765206620e39a41ff6d917
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.vscode
+preprocess
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..29f81d812f3e768fa89638d1f72920dbfd1413a8
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 22bd3995b4a1aec74a6c808e793098cb8eca225d..a1c18102ba5543877749622a37ebe92d9853fb1d 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,167 @@
----
-title: Latte
-emoji: ๐
-colorFrom: blue
-colorTo: pink
-sdk: gradio
-sdk_version: 4.39.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
+---
+title: Latte
+app_file: demo.py
+sdk: gradio
+sdk_version: 4.37.2
+---
+## Latte: Latent Diffusion Transformer for Video Generation
Official PyTorch Implementation
+
+
+
+
+[![Arxiv](https://img.shields.io/badge/Arxiv-b31b1b.svg)](https://arxiv.org/abs/2401.03048)
+[![Project Page](https://img.shields.io/badge/Project-Website-blue)](https://maxin-cn.github.io/latte_project/)
+[![HF Demo](https://img.shields.io/static/v1?label=Demo&message=OpenBayes%E8%B4%9D%E5%BC%8F%E8%AE%A1%E7%AE%97&color=green)](https://openbayes.com/console/public/tutorials/UOeU0ywVxl7)
+
+[![Static Badge](https://img.shields.io/badge/Latte--1%20checkpoint%20(T2V)-HuggingFace-yellow?logoColor=violet%20Latte-1%20checkpoint)](https://huggingface.co./maxin-cn/Latte-1)
+[![Static Badge](https://img.shields.io/badge/Latte%20checkpoint%20-HuggingFace-yellow?logoColor=violet%20Latte%20checkpoint)](https://huggingface.co./maxin-cn/Latte)
+
+This repo contains PyTorch model definitions, pre-trained weights, training/sampling code and evaluation code for our paper exploring
+latent diffusion models with transformers (Latte). You can find more visualizations on our [project page](https://maxin-cn.github.io/latte_project/).
+
+> [**Latte: Latent Diffusion Transformer for Video Generation**](https://maxin-cn.github.io/latte_project/)
+> [Xin Ma](https://maxin-cn.github.io/), [Yaohui Wang*](https://wyhsirius.github.io/), [Xinyuan Chen](https://scholar.google.com/citations?user=3fWSC8YAAAAJ), [Gengyun Jia](https://scholar.google.com/citations?user=_04pkGgAAAAJ&hl=zh-CN), [Ziwei Liu](https://liuziwei7.github.io/), [Yuan-Fang Li](https://users.monash.edu/~yli/), [Cunjian Chen](https://cunjian.github.io/), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ&hl=zh-CN)
+> (*Corresponding Author & Project Lead)
+
+
+
+
+
+
+## News
+- (๐ฅ New) **Jul 11, 2024** ๐ฅ **Latte-1 is now integrated into [diffusers](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/latte_transformer_3d.py). Thanks to [@yiyixuxu](https://github.com/yiyixuxu), [@sayakpaul](https://github.com/sayakpaul), [@a-r-r-o-w](https://github.com/a-r-r-o-w) and [@DN6](https://github.com/DN6).** You can easily run Latte using the following code. We also support inference with 4/8-bit quantization, which can reduce GPU memory from 17 GB to 9 GB. Please refer to this [tutorial](docs/latte_diffusers.md) for more information.
+
+```
+from diffusers import LattePipeline
+from diffusers.models import AutoencoderKLTemporalDecoder
+from torchvision.utils import save_image
+import torch
+import imageio
+
+torch.manual_seed(0)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+video_length = 16 # 1 (text-to-image) or 16 (text-to-video)
+pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to(device)
+
+# Using temporal decoder of VAE
+vae = AutoencoderKLTemporalDecoder.from_pretrained("maxin-cn/Latte-1", subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+pipe.vae = vae
+
+prompt = "a cat wearing sunglasses and working as a lifeguard at pool."
+videos = pipe(prompt, video_length=video_length, output_type='pt').frames.cpu()
+```
+
+- (๐ฅ New) **May 23, 2024** ๐ฅ **Latte-1** is released! Pre-trained model can be downloaded [here](https://huggingface.co./maxin-cn/Latte-1/tree/main/transformer). **We support both T2V and T2I**. Please run `bash sample/t2v.sh` and `bash sample/t2i.sh` respectively.
+
+
+
+- (๐ฅ New) **Feb 24, 2024** ๐ฅ We are very grateful that researchers and developers like our work. We will continue to update our LatteT2V model, hoping that our efforts can help the community develop. Our Latte discord channel
+ is created for discussions. Coders are welcome to contribute.
+
+- (๐ฅ New) **Jan 9, 2024** ๐ฅ An updated LatteT2V model initialized with the [PixArt-ฮฑ](https://github.com/PixArt-alpha/PixArt-alpha) is released, the checkpoint can be found [here](https://huggingface.co./maxin-cn/Latte-0/tree/main/transformer).
+
+- (๐ฅ New) **Oct 31, 2023** ๐ฅ The training and inference code is released. All checkpoints (including FaceForensics, SkyTimelapse, UCF101, and Taichi-HD) can be found [here](https://huggingface.co./maxin-cn/Latte/tree/main). In addition, the LatteT2V inference code is provided.
+
+
+## Setup
+
+First, download and set up the repo:
+
+```bash
+git clone https://github.com/Vchitect/Latte
+cd Latte
+```
+
+We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
+to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
+
+```bash
+conda env create -f environment.yml
+conda activate latte
+```
+
+
+## Sampling
+
+You can sample from our **pre-trained Latte models** with [`sample.py`](sample/sample.py). Weights for our pre-trained Latte model can be found [here](https://huggingface.co./maxin-cn/Latte). The script has various arguments to adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our model on FaceForensics, you can use:
+
+```bash
+bash sample/ffs.sh
+```
+
+or if you want to sample hundreds of videos, you can use the following script with Pytorch DDP:
+
+```bash
+bash sample/ffs_ddp.sh
+```
+
+If you want to try generating videos from text, just run `bash sample/t2v.sh`. All related checkpoints will download automatically.
+
+If you would like to measure the quantitative metrics of your generated results, please refer to [here](docs/datasets_evaluation.md).
+
+## Training
+
+We provide a training script for Latte in [`train.py`](train.py). The structure of the datasets can be found [here](docs/datasets_evaluation.md). This script can be used to train class-conditional and unconditional
+Latte models. To launch Latte (256x256) training with `N` GPUs on the FaceForensics dataset
+:
+
+```bash
+torchrun --nnodes=1 --nproc_per_node=N train.py --config ./configs/ffs/ffs_train.yaml
+```
+
+or If you have a cluster that uses slurm, you can also train Latte's model using the following scripts:
+
+ ```bash
+sbatch slurm_scripts/ffs.slurm
+```
+
+We also provide the video-image joint training scripts [`train_with_img.py`](train_with_img.py). Similar to [`train.py`](train.py) scripts, these scripts can be also used to train class-conditional and unconditional
+Latte models. For example, if you want to train the Latte model on the FaceForensics dataset, you can use:
+
+```bash
+torchrun --nnodes=1 --nproc_per_node=N train_with_img.py --config ./configs/ffs/ffs_img_train.yaml
+```
+
+## Contact Us
+**Yaohui Wang**: [wangyaohui@pjlab.org.cn](mailto:wangyaohui@pjlab.org.cn)
+**Xin Ma**: [xin.ma1@monash.edu](mailto:xin.ma1@monash.edu)
+
+## Citation
+If you find this work useful for your research, please consider citing it.
+```bibtex
+@article{ma2024latte,
+ title={Latte: Latent Diffusion Transformer for Video Generation},
+ author={Ma, Xin and Wang, Yaohui and Jia, Gengyun and Chen, Xinyuan and Liu, Ziwei and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
+ journal={arXiv preprint arXiv:2401.03048},
+ year={2024}
+}
+```
+
+
+## Acknowledgments
+Latte has been greatly inspired by the following amazing works and teams: [DiT](https://github.com/facebookresearch/DiT) and [PixArt-ฮฑ](https://github.com/PixArt-alpha/PixArt-alpha), we thank all the contributors for open-sourcing.
+
+
+## License
+The code and model weights are licensed under [LICENSE](LICENSE).
diff --git a/configs/ffs/ffs_img_train.yaml b/configs/ffs/ffs_img_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..15315fccda1894d401778010f2337a3377596cad
--- /dev/null
+++ b/configs/ffs/ffs_img_train.yaml
@@ -0,0 +1,45 @@
+# dataset
+dataset: "ffs_img"
+
+data_path: "/path/to/datasets/preprocessed_ffs/train/videos/"
+frame_data_path: "/path/to/datasets/preprocessed_ffs/train/images/"
+frame_data_txt: "/path/to/datasets/preprocessed_ffs/train_list.txt"
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results_img"
+pretrained:
+
+# model config:
+model: LatteIMG-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True # important
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+save_ceph: True # important
+use_image_num: 8
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 500000
+local_batch_size: 4 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 100
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes:
+
+# low VRAM and speed up training
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/configs/ffs/ffs_sample.yaml b/configs/ffs/ffs_sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0a223a5f504ab19a3281722e2282f98b18d8773c
--- /dev/null
+++ b/configs/ffs/ffs_sample.yaml
@@ -0,0 +1,30 @@
+# path:
+ckpt: # will be overwrite
+save_img_path: "./sample_videos" # will be overwrite
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+frame_interval: 2
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+num_classes:
+
+# model speedup
+use_compile: False
+use_fp16: True
+
+# sample config:
+seed:
+sample_method: 'ddpm'
+num_sampling_steps: 250
+cfg_scale: 1.0
+negative_name:
+
+# ddp sample config
+per_proc_batch_size: 2
+num_fvd_samples: 2048
\ No newline at end of file
diff --git a/configs/ffs/ffs_train.yaml b/configs/ffs/ffs_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ecd341c8b65e54c4119f1acf181fb33fe3faed44
--- /dev/null
+++ b/configs/ffs/ffs_train.yaml
@@ -0,0 +1,42 @@
+# dataset
+dataset: "ffs"
+
+data_path: "/path/to/datasets/preprocess_ffs/train/videos/" # s
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results"
+pretrained:
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True # important
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+save_ceph: True # important
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 20000
+local_batch_size: 5 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 100
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes:
+
+# low VRAM and speed up training
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/configs/sky/sky_img_train.yaml b/configs/sky/sky_img_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a7b4719553a7d33724f2120afb1f03f6296e7447
--- /dev/null
+++ b/configs/sky/sky_img_train.yaml
@@ -0,0 +1,43 @@
+# dataset
+dataset: "sky_img"
+
+data_path: "/path/to/datasets/sky_timelapse/sky_train/" # s/p
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results_img"
+pretrained:
+
+# model config:
+model: LatteIMG-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+save_ceph: True # important
+use_image_num: 8 # important
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 20000
+local_batch_size: 4 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 50
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes:
+
+# low VRAM and speed up training
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/configs/sky/sky_sample.yaml b/configs/sky/sky_sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea7d7114c02446220542121e5429e3bb34954982
--- /dev/null
+++ b/configs/sky/sky_sample.yaml
@@ -0,0 +1,32 @@
+# path:
+ckpt: # will be overwrite
+save_img_path: "./sample_videos/" # will be overwrite
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+frame_interval: 2
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+num_classes:
+
+# model speedup
+use_compile: False
+use_fp16: True
+
+# sample config:
+seed:
+sample_method: 'ddpm'
+num_sampling_steps: 250
+cfg_scale: 1.0
+run_time: 12
+num_sample: 1
+negative_name:
+
+# ddp sample config
+per_proc_batch_size: 1
+num_fvd_samples: 2
\ No newline at end of file
diff --git a/configs/sky/sky_train.yaml b/configs/sky/sky_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d8014fce7ca489ec7f7d50d99b48706db3f300d
--- /dev/null
+++ b/configs/sky/sky_train.yaml
@@ -0,0 +1,42 @@
+# dataset
+dataset: "sky"
+
+data_path: "/path/to/datasets/sky_timelapse/sky_train/"
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results"
+pretrained:
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+save_ceph: True # important
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 20000
+local_batch_size: 5 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 50
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes:
+
+# low VRAM and speed up training
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/configs/t2x/t2i_sample.yaml b/configs/t2x/t2i_sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a82b1618023e42534285b6c2a200d9b10f6f3e9c
--- /dev/null
+++ b/configs/t2x/t2i_sample.yaml
@@ -0,0 +1,37 @@
+# path:
+save_img_path: "./sample_videos/t2i-"
+pretrained_model_path: "maxin-cn/Latte-1"
+
+# model config:
+# maxin-cn/Latte-0: the first released version
+# maxin-cn/Latte-1: the second version with better performance (released on May. 23, 2024)
+model: LatteT2V
+video_length: 1
+image_size: [512, 512]
+# # beta schedule
+beta_start: 0.0001
+beta_end: 0.02
+beta_schedule: "linear"
+variance_type: "learned_range"
+
+# model speedup
+use_compile: False
+use_fp16: True
+
+# sample config:
+seed:
+run_time: 0
+guidance_scale: 7.5
+sample_method: 'DDIM'
+num_sampling_steps: 50
+enable_temporal_attentions: True # LatteT2V-V0: set to False; LatteT2V-V1: set to True
+enable_vae_temporal_decoder: False
+
+text_prompt: [
+ 'Yellow and black tropical fish dart through the sea.',
+ 'An epic tornado attacking above aglowing city at night.',
+ 'Slow pan upward of blazing oak fire in an indoor fireplace.',
+ 'a cat wearing sunglasses and working as a lifeguard at pool.',
+ 'Sunset over the sea.',
+ 'A dog in astronaut suit and sunglasses floating in space.',
+ ]
\ No newline at end of file
diff --git a/configs/t2x/t2v_sample.yaml b/configs/t2x/t2v_sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5bd1381087425d00359daa9dfad8d8cf1317eeb0
--- /dev/null
+++ b/configs/t2x/t2v_sample.yaml
@@ -0,0 +1,37 @@
+# path:
+save_img_path: "./sample_videos/t2v-"
+pretrained_model_path: "/data/monash_vidgen/pretrained/Latte-1"
+
+# model config:
+# maxin-cn/Latte-0: the first released version
+# maxin-cn/Latte-1: the second version with better performance (released on May. 23, 2024)
+model: LatteT2V
+video_length: 16
+image_size: [512, 512]
+# # beta schedule
+beta_start: 0.0001
+beta_end: 0.02
+beta_schedule: "linear"
+variance_type: "learned_range"
+
+# model speedup
+use_compile: False
+use_fp16: True
+
+# sample config:
+seed: 0
+run_time: 0
+guidance_scale: 7.5
+sample_method: 'DDIM'
+num_sampling_steps: 50
+enable_temporal_attentions: True
+enable_vae_temporal_decoder: True # use temporal vae decoder from SVD, maybe reduce the video flicker (It's not widely tested)
+
+text_prompt: [
+ 'Yellow and black tropical fish dart through the sea.',
+ 'An epic tornado attacking above aglowing city at night.',
+ 'Slow pan upward of blazing oak fire in an indoor fireplace.',
+ 'a cat wearing sunglasses and working as a lifeguard at pool.',
+ 'Sunset over the sea.',
+ 'A dog in astronaut suit and sunglasses floating in space.',
+ ]
\ No newline at end of file
diff --git a/configs/taichi/taichi_img_train.yaml b/configs/taichi/taichi_img_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c4e34435d6098b205e04f03505aeae56c3bc00b
--- /dev/null
+++ b/configs/taichi/taichi_img_train.yaml
@@ -0,0 +1,43 @@
+# dataset
+dataset: "taichi_img"
+
+data_path: "/path/to/datasets/taichi"
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results_img"
+pretrained:
+
+# model config:
+model: LatteIMG-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+load_from_ceph: False # important
+use_image_num: 8
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 500000
+local_batch_size: 4 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 50
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes:
+
+# low VRAM and speed up training TODO
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/configs/taichi/taichi_sample.yaml b/configs/taichi/taichi_sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..559bffa89d225062c74b5ed78ebdbe628016cf26
--- /dev/null
+++ b/configs/taichi/taichi_sample.yaml
@@ -0,0 +1,30 @@
+# path:
+ckpt: # will be overwrite
+save_img_path: "./sample_videos/" # will be overwrite
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+frame_interval: 2
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+num_classes:
+
+# model speedup
+use_compile: False
+use_fp16: True
+
+# sample config:
+seed:
+sample_method: 'ddpm'
+num_sampling_steps: 250
+cfg_scale: 1.0
+negative_name:
+
+# ddp sample config
+per_proc_batch_size: 1
+num_fvd_samples: 2
\ No newline at end of file
diff --git a/configs/taichi/taichi_train.yaml b/configs/taichi/taichi_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9bfdbd4921c65631057f733cd5f734dd2a0516c9
--- /dev/null
+++ b/configs/taichi/taichi_train.yaml
@@ -0,0 +1,42 @@
+# dataset
+dataset: "taichi"
+
+data_path: "/path/to/datasets/taichi"
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results"
+pretrained:
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+load_from_ceph: False # important
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 500000
+local_batch_size: 5 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 50
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes:
+
+# low VRAM and speed up training TODO
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/configs/ucf101/ucf101_img_train.yaml b/configs/ucf101/ucf101_img_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ba8c8471b1086a2c3f94ed0f8070e90ef66ace5f
--- /dev/null
+++ b/configs/ucf101/ucf101_img_train.yaml
@@ -0,0 +1,44 @@
+# dataset
+dataset: "ucf101_img"
+
+data_path: "/path/to/datasets/UCF101/videos/"
+frame_data_txt: "/path/to/datasets/UCF101/train_256_list.txt"
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results_img"
+pretrained:
+
+# model config:
+model: LatteIMG-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 2 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+save_ceph: True # important
+use_image_num: 8 # important
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 100000
+local_batch_size: 4 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 50
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes: 101
+
+# low VRAM and speed up training
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/configs/ucf101/ucf101_sample.yaml b/configs/ucf101/ucf101_sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..57e0c6ddb3864eceee134a8f736820bdabf612e1
--- /dev/null
+++ b/configs/ucf101/ucf101_sample.yaml
@@ -0,0 +1,33 @@
+# path:
+ckpt:
+save_img_path: "./sample_videos/"
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 2 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+num_classes: 101
+
+# model speedup
+use_compile: False
+use_fp16: True
+
+# sample config:
+seed:
+sample_method: 'ddpm'
+num_sampling_steps: 250
+cfg_scale: 7.0
+run_time: 12
+num_sample: 1
+sample_names: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+negative_name: 101
+
+# ddp sample config
+per_proc_batch_size: 2
+num_fvd_samples: 2
\ No newline at end of file
diff --git a/configs/ucf101/ucf101_train.yaml b/configs/ucf101/ucf101_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..163ee82ae7c3a5b5e06c09b31bf13b79bccc109a
--- /dev/null
+++ b/configs/ucf101/ucf101_train.yaml
@@ -0,0 +1,42 @@
+# dataset
+dataset: "ucf101"
+
+data_path: "/path/to/datasets/UCF101/videos/"
+pretrained_model_path: "/path/to/pretrained/Latte/"
+
+# save and load
+results_dir: "./results"
+pretrained:
+
+# model config:
+model: Latte-XL/2
+num_frames: 16
+image_size: 256 # choices=[256, 512]
+num_sampling_steps: 250
+frame_interval: 3
+fixed_spatial: False
+attention_bias: True
+learn_sigma: True
+extras: 2 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+# train config:
+save_ceph: True # important
+learning_rate: 1e-4
+ckpt_every: 10000
+clip_max_norm: 0.1
+start_clip_iter: 100000
+local_batch_size: 5 # important
+max_train_steps: 1000000
+global_seed: 3407
+num_workers: 8
+log_every: 50
+lr_warmup_steps: 0
+resume_from_checkpoint:
+gradient_accumulation_steps: 1 # TODO
+num_classes: 101
+
+# low VRAM and speed up training
+use_compile: False
+mixed_precision: False
+enable_xformers_memory_efficient_attention: False
+gradient_checkpointing: False
\ No newline at end of file
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76dd08db8b3a2f89fd9af79a0a99730b1a1e3e2f
--- /dev/null
+++ b/datasets/__init__.py
@@ -0,0 +1,79 @@
+from .sky_datasets import Sky
+from torchvision import transforms
+from .taichi_datasets import Taichi
+from datasets import video_transforms
+from .ucf101_datasets import UCF101
+from .ffs_datasets import FaceForensics
+from .ffs_image_datasets import FaceForensicsImages
+from .sky_image_datasets import SkyImages
+from .ucf101_image_datasets import UCF101Images
+from .taichi_image_datasets import TaichiImages
+
+
+def get_dataset(args):
+ temporal_sample = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) # 16 1
+
+ if args.dataset == 'ffs':
+ transform_ffs = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ video_transforms.UCFCenterCropVideo(args.image_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return FaceForensics(args, transform=transform_ffs, temporal_sample=temporal_sample)
+ elif args.dataset == 'ffs_img':
+ transform_ffs = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ video_transforms.UCFCenterCropVideo(args.image_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return FaceForensicsImages(args, transform=transform_ffs, temporal_sample=temporal_sample)
+ elif args.dataset == 'ucf101':
+ transform_ucf101 = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ video_transforms.UCFCenterCropVideo(args.image_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return UCF101(args, transform=transform_ucf101, temporal_sample=temporal_sample)
+ elif args.dataset == 'ucf101_img':
+ transform_ucf101 = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ video_transforms.UCFCenterCropVideo(args.image_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return UCF101Images(args, transform=transform_ucf101, temporal_sample=temporal_sample)
+ elif args.dataset == 'taichi':
+ transform_taichi = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return Taichi(args, transform=transform_taichi, temporal_sample=temporal_sample)
+ elif args.dataset == 'taichi_img':
+ transform_taichi = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return TaichiImages(args, transform=transform_taichi, temporal_sample=temporal_sample)
+ elif args.dataset == 'sky':
+ transform_sky = transforms.Compose([
+ video_transforms.ToTensorVideo(),
+ video_transforms.CenterCropResizeVideo(args.image_size),
+ # video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return Sky(args, transform=transform_sky, temporal_sample=temporal_sample)
+ elif args.dataset == 'sky_img':
+ transform_sky = transforms.Compose([
+ video_transforms.ToTensorVideo(),
+ video_transforms.CenterCropResizeVideo(args.image_size),
+ # video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ return SkyImages(args, transform=transform_sky, temporal_sample=temporal_sample)
+ else:
+ raise NotImplementedError(args.dataset)
\ No newline at end of file
diff --git a/datasets/ffs_datasets.py b/datasets/ffs_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..aceb99164d215cc6634facd2a8013173683bdec6
--- /dev/null
+++ b/datasets/ffs_datasets.py
@@ -0,0 +1,164 @@
+import os
+import json
+import torch
+import decord
+import torchvision
+
+import numpy as np
+
+
+from PIL import Image
+from einops import rearrange
+from typing import Dict, List, Tuple
+
+class_labels_map = None
+cls_sample_cnt = None
+
+def temporal_sampling(frames, start_idx, end_idx, num_samples):
+ """
+ Given the start and end frame index, sample num_samples frames between
+ the start and end with equal interval.
+ Args:
+ frames (tensor): a tensor of video frames, dimension is
+ `num video frames` x `channel` x `height` x `width`.
+ start_idx (int): the index of the start frame.
+ end_idx (int): the index of the end frame.
+ num_samples (int): number of frames to sample.
+ Returns:
+ frames (tersor): a tensor of temporal sampled video frames, dimension is
+ `num clip frames` x `channel` x `height` x `width`.
+ """
+ index = torch.linspace(start_idx, end_idx, num_samples)
+ index = torch.clamp(index, 0, frames.shape[0] - 1).long()
+ frames = torch.index_select(frames, 0, index)
+ return frames
+
+
+def numpy2tensor(x):
+ return torch.from_numpy(x)
+
+
+def get_filelist(file_path):
+ Filelist = []
+ for home, dirs, files in os.walk(file_path):
+ for filename in files:
+ Filelist.append(os.path.join(home, filename))
+ # Filelist.append( filename)
+ return Filelist
+
+
+def load_annotation_data(data_file_path):
+ with open(data_file_path, 'r') as data_file:
+ return json.load(data_file)
+
+
+def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
+ global class_labels_map, cls_sample_cnt
+
+ if class_labels_map is not None:
+ return class_labels_map, cls_sample_cnt
+ else:
+ cls_sample_cnt = {}
+ class_labels_map = load_annotation_data(anno_pth)
+ for cls in class_labels_map:
+ cls_sample_cnt[cls] = 0
+ return class_labels_map, cls_sample_cnt
+
+
+def load_annotations(ann_file, num_class, num_samples_per_cls):
+ dataset = []
+ class_to_idx, cls_sample_cnt = get_class_labels(num_class)
+ with open(ann_file, 'r') as fin:
+ for line in fin:
+ line_split = line.strip().split('\t')
+ sample = {}
+ idx = 0
+ # idx for frame_dir
+ frame_dir = line_split[idx]
+ sample['video'] = frame_dir
+ idx += 1
+
+ # idx for label[s]
+ label = [x for x in line_split[idx:]]
+ assert label, f'missing label in line: {line}'
+ assert len(label) == 1
+ class_name = label[0]
+ class_index = int(class_to_idx[class_name])
+
+ # choose a class subset of whole dataset
+ if class_index < num_class:
+ sample['label'] = class_index
+ if cls_sample_cnt[class_name] < num_samples_per_cls:
+ dataset.append(sample)
+ cls_sample_cnt[class_name]+=1
+
+ return dataset
+
+
+class DecordInit(object):
+ """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
+
+ def __init__(self, num_threads=1, **kwargs):
+ self.num_threads = num_threads
+ self.ctx = decord.cpu(0)
+ self.kwargs = kwargs
+
+ def __call__(self, filename):
+ """Perform the Decord initialization.
+ Args:
+ results (dict): The resulting dict to be modified and passed
+ to the next transform in pipeline.
+ """
+ reader = decord.VideoReader(filename,
+ ctx=self.ctx,
+ num_threads=self.num_threads)
+ return reader
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'sr={self.sr},'
+ f'num_threads={self.num_threads})')
+ return repr_str
+
+
+class FaceForensics(torch.utils.data.Dataset):
+ """Load the FaceForensics video files
+
+ Args:
+ target_video_len (int): the number of video frames will be load.
+ align_transform (callable): Align different videos in a specified size.
+ temporal_sample (callable): Sample the target length of a video.
+ """
+
+ def __init__(self,
+ configs,
+ transform=None,
+ temporal_sample=None):
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.video_lists = get_filelist(configs.data_path)
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.v_decoder = DecordInit()
+
+ def __getitem__(self, index):
+ path = self.video_lists[index]
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
+ video = vframes[frame_indice]
+ # videotransformer data proprecess
+ video = self.transform(video) # T C H W
+ return {'video': video, 'video_name': 1}
+
+ def __len__(self):
+ return len(self.video_lists)
+
+
+if __name__ == '__main__':
+ pass
\ No newline at end of file
diff --git a/datasets/ffs_image_datasets.py b/datasets/ffs_image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..1140e16cac1cad01a55fe4e61579e01ad4cfa668
--- /dev/null
+++ b/datasets/ffs_image_datasets.py
@@ -0,0 +1,246 @@
+import os
+import json
+import torch
+import decord
+import torchvision
+
+import numpy as np
+
+import random
+from PIL import Image
+from einops import rearrange
+from typing import Dict, List, Tuple
+from torchvision import transforms
+import traceback
+
+class_labels_map = None
+cls_sample_cnt = None
+
+def temporal_sampling(frames, start_idx, end_idx, num_samples):
+ """
+ Given the start and end frame index, sample num_samples frames between
+ the start and end with equal interval.
+ Args:
+ frames (tensor): a tensor of video frames, dimension is
+ `num video frames` x `channel` x `height` x `width`.
+ start_idx (int): the index of the start frame.
+ end_idx (int): the index of the end frame.
+ num_samples (int): number of frames to sample.
+ Returns:
+ frames (tersor): a tensor of temporal sampled video frames, dimension is
+ `num clip frames` x `channel` x `height` x `width`.
+ """
+ index = torch.linspace(start_idx, end_idx, num_samples)
+ index = torch.clamp(index, 0, frames.shape[0] - 1).long()
+ frames = torch.index_select(frames, 0, index)
+ return frames
+
+
+def numpy2tensor(x):
+ return torch.from_numpy(x)
+
+
+def get_filelist(file_path):
+ Filelist = []
+ for home, dirs, files in os.walk(file_path):
+ for filename in files:
+ # ๆไปถๅๅ่กจ๏ผๅ
ๅซๅฎๆด่ทฏๅพ
+ Filelist.append(os.path.join(home, filename))
+ # # ๆไปถๅๅ่กจ๏ผๅชๅ
ๅซๆไปถๅ
+ # Filelist.append( filename)
+ return Filelist
+
+
+def load_annotation_data(data_file_path):
+ with open(data_file_path, 'r') as data_file:
+ return json.load(data_file)
+
+
+def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
+ global class_labels_map, cls_sample_cnt
+
+ if class_labels_map is not None:
+ return class_labels_map, cls_sample_cnt
+ else:
+ cls_sample_cnt = {}
+ class_labels_map = load_annotation_data(anno_pth)
+ for cls in class_labels_map:
+ cls_sample_cnt[cls] = 0
+ return class_labels_map, cls_sample_cnt
+
+
+def load_annotations(ann_file, num_class, num_samples_per_cls):
+ dataset = []
+ class_to_idx, cls_sample_cnt = get_class_labels(num_class)
+ with open(ann_file, 'r') as fin:
+ for line in fin:
+ line_split = line.strip().split('\t')
+ sample = {}
+ idx = 0
+ # idx for frame_dir
+ frame_dir = line_split[idx]
+ sample['video'] = frame_dir
+ idx += 1
+
+ # idx for label[s]
+ label = [x for x in line_split[idx:]]
+ assert label, f'missing label in line: {line}'
+ assert len(label) == 1
+ class_name = label[0]
+ class_index = int(class_to_idx[class_name])
+
+ # choose a class subset of whole dataset
+ if class_index < num_class:
+ sample['label'] = class_index
+ if cls_sample_cnt[class_name] < num_samples_per_cls:
+ dataset.append(sample)
+ cls_sample_cnt[class_name]+=1
+
+ return dataset
+
+
+class DecordInit(object):
+ """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
+
+ def __init__(self, num_threads=1, **kwargs):
+ self.num_threads = num_threads
+ self.ctx = decord.cpu(0)
+ self.kwargs = kwargs
+
+ def __call__(self, filename):
+ """Perform the Decord initialization.
+ Args:
+ results (dict): The resulting dict to be modified and passed
+ to the next transform in pipeline.
+ """
+ reader = decord.VideoReader(filename,
+ ctx=self.ctx,
+ num_threads=self.num_threads)
+ return reader
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'sr={self.sr},'
+ f'num_threads={self.num_threads})')
+ return repr_str
+
+
+class FaceForensicsImages(torch.utils.data.Dataset):
+ """Load the FaceForensics video files
+
+ Args:
+ target_video_len (int): the number of video frames will be load.
+ align_transform (callable): Align different videos in a specified size.
+ temporal_sample (callable): Sample the target length of a video.
+ """
+
+ def __init__(self,
+ configs,
+ transform=None,
+ temporal_sample=None):
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.video_lists = get_filelist(configs.data_path)
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.v_decoder = DecordInit()
+ self.video_length = len(self.video_lists)
+
+ # ffs video frames
+ self.video_frame_path = configs.frame_data_path
+ self.video_frame_txt = configs.frame_data_txt
+ self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
+ random.shuffle(self.video_frame_files)
+ self.use_image_num = configs.use_image_num
+ self.image_tranform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ def __getitem__(self, index):
+ video_index = index % self.video_length
+ path = self.video_lists[video_index]
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
+ video = vframes[frame_indice]
+ # videotransformer data proprecess
+ video = self.transform(video) # T C H W
+
+ # get video frames
+ images = []
+ for i in range(self.use_image_num):
+ while True:
+ try:
+ image = Image.open(os.path.join(self.video_frame_path, self.video_frame_files[index+i])).convert("RGB")
+ image = self.image_tranform(image).unsqueeze(0)
+ images.append(image)
+ break
+ except Exception as e:
+ traceback.print_exc()
+ index = random.randint(0, len(self.video_frame_files) - self.use_image_num)
+ images = torch.cat(images, dim=0)
+
+ assert len(images) == self.use_image_num
+
+ video_cat = torch.cat([video, images], dim=0)
+
+ return {'video': video_cat, 'video_name': 1}
+
+ def __len__(self):
+ return len(self.video_frame_files)
+
+
+if __name__ == '__main__':
+ import argparse
+ import torchvision
+ import video_transforms
+
+ import torch.utils.data as Data
+ import torchvision.transforms as transform
+
+ from PIL import Image
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_frames", type=int, default=16)
+ parser.add_argument("--use-image-num", type=int, default=5)
+ parser.add_argument("--frame_interval", type=int, default=3)
+ parser.add_argument("--dataset", type=str, default='webvideo10m')
+ parser.add_argument("--test-run", type=bool, default='')
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/videos/")
+ parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
+ parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/faceForensics_v1/train_list.txt")
+ config = parser.parse_args()
+
+ temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
+
+ transform_webvideo = transform.Compose([
+ video_transforms.ToTensorVideo(),
+ transform.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ])
+
+ dataset = FaceForensicsImages(config, transform=transform_webvideo, temporal_sample=temporal_sample)
+ dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=4)
+
+ for i, video_data in enumerate(dataloader):
+ video, video_label = video_data['video'], video_data['video_name']
+ # print(video_label)
+ # print(image_label)
+ print(video.shape)
+ print(video_label)
+ # video_ = ((video[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ # print(video_.shape)
+ # try:
+ # torchvision.io.write_video(f'./test/{i:03d}_{video_label}.mp4', video_[:16], fps=8)
+ # except:
+ # pass
+
+ # if i % 100 == 0 and i != 0:
+ # break
+ print('Done!')
\ No newline at end of file
diff --git a/datasets/sky_datasets.py b/datasets/sky_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5d6f2217f56f135fc8dd70e39ee3386c3139f2a
--- /dev/null
+++ b/datasets/sky_datasets.py
@@ -0,0 +1,110 @@
+import os
+import torch
+import random
+import torch.utils.data as data
+
+import numpy as np
+
+from PIL import Image
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+class Sky(data.Dataset):
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
+
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.frame_interval = self.configs.frame_interval
+ self.data_all = self.load_video_frames(self.data_path)
+
+ def __getitem__(self, index):
+
+ vframes = self.data_all[index]
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num=self.target_video_len, dtype=int) # start, stop, num=50
+
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
+
+ video_frames = []
+ for path in select_video_frames:
+ video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ video_frames.append(video_frame)
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
+ video_clip = self.transform(video_clip)
+
+ return {'video': video_clip, 'video_name': 1}
+
+ def __len__(self):
+ return self.video_num
+
+ def load_video_frames(self, dataroot):
+ data_all = []
+ frame_list = os.walk(dataroot)
+ for _, meta in enumerate(frame_list):
+ root = meta[0]
+ try:
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
+ except:
+ print(meta[0]) # root
+ print(meta[2]) # files
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
+ if len(frames) > max(0, self.target_video_len * self.frame_interval): # need all > (16 * frame-interval) videos
+ # if len(frames) >= max(0, self.target_video_len): # need all > 16 frames videos
+ data_all.append(frames)
+ self.video_num = len(data_all)
+ return data_all
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import torchvision
+ import video_transforms
+ import torch.utils.data as data
+
+ from torchvision import transforms
+ from torchvision.utils import save_image
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_frames", type=int, default=16)
+ parser.add_argument("--frame_interval", type=int, default=4)
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
+ config = parser.parse_args()
+
+
+ target_video_len = config.num_frames
+
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
+ trans = transforms.Compose([
+ video_transforms.ToTensorVideo(),
+ # video_transforms.CenterCropVideo(256),
+ video_transforms.CenterCropResizeVideo(256),
+ # video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ taichi_dataset = Sky(config, transform=trans, temporal_sample=temporal_sample)
+ print(len(taichi_dataset))
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
+
+ for i, video_data in enumerate(taichi_dataloader):
+ print(video_data['video'].shape)
+
+ # print(video_data.dtype)
+ # for i in range(target_video_len):
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
+
+ # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
+ # exit()
\ No newline at end of file
diff --git a/datasets/sky_image_datasets.py b/datasets/sky_image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac3c709ae3f032140f4efaa012da2da124f38c5
--- /dev/null
+++ b/datasets/sky_image_datasets.py
@@ -0,0 +1,137 @@
+import os
+import torch
+import random
+import torch.utils.data as data
+import numpy as np
+import copy
+from PIL import Image
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+class SkyImages(data.Dataset):
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
+
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.frame_interval = self.configs.frame_interval
+ self.data_all, self.video_frame_all = self.load_video_frames(self.data_path)
+
+ # sky video frames
+ random.shuffle(self.video_frame_all)
+ self.use_image_num = configs.use_image_num
+
+ def __getitem__(self, index):
+
+ video_index = index % self.video_num
+ vframes = self.data_all[video_index]
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num=self.target_video_len, dtype=int) # start, stop, num=50
+
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
+
+ video_frames = []
+ for path in select_video_frames:
+ video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ video_frames.append(video_frame)
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
+ video_clip = self.transform(video_clip)
+
+ # get video frames
+ images = []
+
+ for i in range(self.use_image_num):
+ while True:
+ try:
+ video_frame_path = self.video_frame_all[index+i]
+ image = torch.as_tensor(np.array(Image.open(video_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ images.append(image)
+ break
+ except Exception as e:
+ index = random.randint(0, self.video_frame_num - self.use_image_num)
+
+ images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
+ images = self.transform(images)
+ assert len(images) == self.use_image_num
+
+ video_cat = torch.cat([video_clip, images], dim=0)
+
+ return {'video': video_cat, 'video_name': 1}
+
+ def __len__(self):
+ return self.video_frame_num
+
+ def load_video_frames(self, dataroot):
+ data_all = []
+ frames_all = []
+ frame_list = os.walk(dataroot)
+ for _, meta in enumerate(frame_list):
+ root = meta[0]
+ try:
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
+ except:
+ print(meta[0]) # root
+ print(meta[2]) # files
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
+ if len(frames) > max(0, self.target_video_len * self.frame_interval): # need all > (16 * frame-interval) videos
+ # if len(frames) >= max(0, self.target_video_len): # need all > 16 frames videos
+ data_all.append(frames)
+ for frame in frames:
+ frames_all.append(frame)
+ self.video_num = len(data_all)
+ self.video_frame_num = len(frames_all)
+ return data_all, frames_all
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import torchvision
+ import video_transforms
+ import torch.utils.data as data
+
+ from torchvision import transforms
+ from torchvision.utils import save_image
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_frames", type=int, default=16)
+ parser.add_argument("--frame_interval", type=int, default=3)
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
+ parser.add_argument("--use-image-num", type=int, default=5)
+ config = parser.parse_args()
+
+ target_video_len = config.num_frames
+
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
+ trans = transforms.Compose([
+ video_transforms.ToTensorVideo(),
+ # video_transforms.CenterCropVideo(256),
+ video_transforms.CenterCropResizeVideo(256),
+ # video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ taichi_dataset = SkyImages(config, transform=trans, temporal_sample=temporal_sample)
+ print(len(taichi_dataset))
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
+
+ for i, video_data in enumerate(taichi_dataloader):
+ print(video_data['video'].shape)
+
+ # print(video_data.dtype)
+ # for i in range(target_video_len):
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
+
+ # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
+ # exit()
\ No newline at end of file
diff --git a/datasets/taichi_datasets.py b/datasets/taichi_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d64b436cee764eb262211dec458b2f509e6cdcc
--- /dev/null
+++ b/datasets/taichi_datasets.py
@@ -0,0 +1,108 @@
+import os
+import torch
+import random
+import torch.utils.data as data
+
+import numpy as np
+import io
+import json
+from PIL import Image
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+class Taichi(data.Dataset):
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
+
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.frame_interval = self.configs.frame_interval
+ self.data_all = self.load_video_frames(self.data_path)
+ self.video_num = len(self.data_all)
+
+ def __getitem__(self, index):
+
+ vframes = self.data_all[index]
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
+
+ video_frames = []
+ for path in select_video_frames:
+ image = Image.open(path).convert('RGB')
+ video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
+ video_frames.append(video_frame)
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
+ video_clip = self.transform(video_clip)
+
+ # return video_clip, 1
+ return {'video': video_clip, 'video_name': 1}
+
+ def __len__(self):
+ return self.video_num
+
+ def load_video_frames(self, dataroot):
+ data_all = []
+ frame_list = os.walk(dataroot)
+ for _, meta in enumerate(frame_list):
+ root = meta[0]
+ try:
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
+ except:
+ print(meta[0], meta[2])
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
+ # if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
+ if len(frames) != 0:
+ data_all.append(frames)
+ # self.video_num = len(data_all)
+ return data_all
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import torchvision
+ import video_transforms
+ import torch.utils.data as data
+
+ from torchvision import transforms
+ from torchvision.utils import save_image
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_frames", type=int, default=16)
+ parser.add_argument("--frame_interval", type=int, default=4)
+ parser.add_argument("--load_fron_ceph", type=bool, default=True)
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
+ config = parser.parse_args()
+
+
+ target_video_len = config.num_frames
+
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
+ trans = transforms.Compose([
+ video_transforms.ToTensorVideo(),
+ video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ taichi_dataset = Taichi(config, transform=trans, temporal_sample=temporal_sample)
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
+
+ for i, video_data in enumerate(taichi_dataloader):
+ print(video_data['video'].shape)
+ # print(video_data.dtype)
+ # for i in range(target_video_len):
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
+
+ # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
+ # exit()
\ No newline at end of file
diff --git a/datasets/taichi_image_datasets.py b/datasets/taichi_image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..d657dcbc13e7544ef75d7228989ca7cb9efe6eac
--- /dev/null
+++ b/datasets/taichi_image_datasets.py
@@ -0,0 +1,139 @@
+import os
+import torch
+import random
+import torch.utils.data as data
+
+import numpy as np
+import io
+import json
+from PIL import Image
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+class TaichiImages(data.Dataset):
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
+
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.frame_interval = self.configs.frame_interval
+ self.data_all, self.video_frame_all = self.load_video_frames(self.data_path)
+ self.video_num = len(self.data_all)
+ self.video_frame_num = len(self.video_frame_all)
+
+ # sky video frames
+ random.shuffle(self.video_frame_all)
+ self.use_image_num = configs.use_image_num
+
+ def __getitem__(self, index):
+
+ video_index = index % self.video_num
+ vframes = self.data_all[video_index]
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
+ # print(frame_indice)
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
+
+ video_frames = []
+ for path in select_video_frames:
+ image = Image.open(path).convert('RGB')
+ video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
+ video_frames.append(video_frame)
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
+ video_clip = self.transform(video_clip)
+
+ # get video frames
+ images = []
+ for i in range(self.use_image_num):
+ while True:
+ try:
+ video_frame_path = self.video_frame_all[index+i]
+ image_path = os.path.join(self.data_path, video_frame_path)
+ image = Image.open(image_path).convert('RGB')
+ image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
+ images.append(image)
+ break
+ except Exception as e:
+ index = random.randint(0, self.video_frame_num - self.use_image_num)
+
+ images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
+ images = self.transform(images)
+ assert len(images) == self.use_image_num
+
+ video_cat = torch.cat([video_clip, images], dim=0)
+
+ return {'video': video_cat, 'video_name': 1}
+
+ def __len__(self):
+ return self.video_frame_num
+
+ def load_video_frames(self, dataroot):
+ data_all = []
+ frames_all = []
+ frame_list = os.walk(dataroot)
+ for _, meta in enumerate(frame_list):
+ root = meta[0]
+ try:
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
+ except:
+ print(meta[0], meta[2])
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
+ # if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
+ if len(frames) != 0:
+ data_all.append(frames)
+ for frame in frames:
+ frames_all.append(frame)
+ # self.video_num = len(data_all)
+ return data_all, frames_all
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import torchvision
+ import video_transforms
+ import torch.utils.data as data
+
+ from torchvision import transforms
+ from torchvision.utils import save_image
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_frames", type=int, default=16)
+ parser.add_argument("--frame_interval", type=int, default=4)
+ parser.add_argument("--load_from_ceph", type=bool, default=True)
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
+ parser.add_argument("--use-image-num", type=int, default=5)
+ config = parser.parse_args()
+
+
+ target_video_len = config.num_frames
+
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
+ trans = transforms.Compose([
+ video_transforms.ToTensorVideo(),
+ video_transforms.RandomHorizontalFlipVideo(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ taichi_dataset = TaichiImages(config, transform=trans, temporal_sample=temporal_sample)
+ print(len(taichi_dataset))
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
+
+ for i, video_data in enumerate(taichi_dataloader):
+ print(video_data['video'].shape)
+ # print(video_data.dtype)
+ # for i in range(target_video_len):
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
+
+ video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
+ exit()
\ No newline at end of file
diff --git a/datasets/ucf101_datasets.py b/datasets/ucf101_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..254b6dc7945b75353074b177ab3e150c15a64a03
--- /dev/null
+++ b/datasets/ucf101_datasets.py
@@ -0,0 +1,229 @@
+import os
+import re
+import json
+import torch
+import decord
+import torchvision
+import numpy as np
+
+
+from PIL import Image
+from einops import rearrange
+from typing import Dict, List, Tuple
+
+class_labels_map = None
+cls_sample_cnt = None
+
+class_labels_map = None
+cls_sample_cnt = None
+
+
+def temporal_sampling(frames, start_idx, end_idx, num_samples):
+ """
+ Given the start and end frame index, sample num_samples frames between
+ the start and end with equal interval.
+ Args:
+ frames (tensor): a tensor of video frames, dimension is
+ `num video frames` x `channel` x `height` x `width`.
+ start_idx (int): the index of the start frame.
+ end_idx (int): the index of the end frame.
+ num_samples (int): number of frames to sample.
+ Returns:
+ frames (tersor): a tensor of temporal sampled video frames, dimension is
+ `num clip frames` x `channel` x `height` x `width`.
+ """
+ index = torch.linspace(start_idx, end_idx, num_samples)
+ index = torch.clamp(index, 0, frames.shape[0] - 1).long()
+ frames = torch.index_select(frames, 0, index)
+ return frames
+
+
+def get_filelist(file_path):
+ Filelist = []
+ for home, dirs, files in os.walk(file_path):
+ for filename in files:
+ # ๆไปถๅๅ่กจ๏ผๅ
ๅซๅฎๆด่ทฏๅพ
+ Filelist.append(os.path.join(home, filename))
+ # # ๆไปถๅๅ่กจ๏ผๅชๅ
ๅซๆไปถๅ
+ # Filelist.append( filename)
+ return Filelist
+
+
+def load_annotation_data(data_file_path):
+ with open(data_file_path, 'r') as data_file:
+ return json.load(data_file)
+
+
+def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
+ global class_labels_map, cls_sample_cnt
+
+ if class_labels_map is not None:
+ return class_labels_map, cls_sample_cnt
+ else:
+ cls_sample_cnt = {}
+ class_labels_map = load_annotation_data(anno_pth)
+ for cls in class_labels_map:
+ cls_sample_cnt[cls] = 0
+ return class_labels_map, cls_sample_cnt
+
+
+def load_annotations(ann_file, num_class, num_samples_per_cls):
+ dataset = []
+ class_to_idx, cls_sample_cnt = get_class_labels(num_class)
+ with open(ann_file, 'r') as fin:
+ for line in fin:
+ line_split = line.strip().split('\t')
+ sample = {}
+ idx = 0
+ # idx for frame_dir
+ frame_dir = line_split[idx]
+ sample['video'] = frame_dir
+ idx += 1
+
+ # idx for label[s]
+ label = [x for x in line_split[idx:]]
+ assert label, f'missing label in line: {line}'
+ assert len(label) == 1
+ class_name = label[0]
+ class_index = int(class_to_idx[class_name])
+
+ # choose a class subset of whole dataset
+ if class_index < num_class:
+ sample['label'] = class_index
+ if cls_sample_cnt[class_name] < num_samples_per_cls:
+ dataset.append(sample)
+ cls_sample_cnt[class_name]+=1
+
+ return dataset
+
+
+def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
+ """Finds the class folders in a dataset.
+
+ See :class:`DatasetFolder` for details.
+ """
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+ if not classes:
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+ return classes, class_to_idx
+
+
+class DecordInit(object):
+ """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
+
+ def __init__(self, num_threads=1):
+ self.num_threads = num_threads
+ self.ctx = decord.cpu(0)
+
+ def __call__(self, filename):
+ """Perform the Decord initialization.
+ Args:
+ results (dict): The resulting dict to be modified and passed
+ to the next transform in pipeline.
+ """
+ reader = decord.VideoReader(filename,
+ ctx=self.ctx,
+ num_threads=self.num_threads)
+ return reader
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'sr={self.sr},'
+ f'num_threads={self.num_threads})')
+ return repr_str
+
+
+class UCF101(torch.utils.data.Dataset):
+ """Load the UCF101 video files
+
+ Args:
+ target_video_len (int): the number of video frames will be load.
+ align_transform (callable): Align different videos in a specified size.
+ temporal_sample (callable): Sample the target length of a video.
+ """
+
+ def __init__(self,
+ configs,
+ transform=None,
+ temporal_sample=None):
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.video_lists = get_filelist(configs.data_path)
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.v_decoder = DecordInit()
+ self.classes, self.class_to_idx = find_classes(self.data_path)
+ # print(self.class_to_idx)
+ # exit()
+
+ def __getitem__(self, index):
+ path = self.video_lists[index]
+ class_name = path.split('/')[-2]
+ class_index = self.class_to_idx[class_name]
+
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
+ # print(frame_indice)
+ video = vframes[frame_indice] #
+ video = self.transform(video) # T C H W
+
+ return {'video': video, 'video_name': class_index}
+
+ def __len__(self):
+ return len(self.video_lists)
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import video_transforms
+ import torch.utils.data as Data
+ import torchvision.transforms as transforms
+
+ from PIL import Image
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_frames", type=int, default=16)
+ parser.add_argument("--frame_interval", type=int, default=1)
+ # parser.add_argument("--data-path", type=str, default="/nvme/share_data/datasets/UCF101/videos")
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/UCF101/videos/")
+ config = parser.parse_args()
+
+
+ temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
+
+ transform_ucf101 = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ video_transforms.UCFCenterCropVideo(256),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+
+ ffs_dataset = UCF101(config, transform=transform_ucf101, temporal_sample=temporal_sample)
+ ffs_dataloader = Data.DataLoader(dataset=ffs_dataset, batch_size=6, shuffle=False, num_workers=1)
+
+ # for i, video_data in enumerate(ffs_dataloader):
+ for video_data in ffs_dataloader:
+ print(type(video_data))
+ video = video_data['video']
+ video_name = video_data['video_name']
+ print(video.shape)
+ print(video_name)
+ # print(video_data[2])
+
+ # for i in range(16):
+ # img0 = rearrange(video_data[0][0][i], 'c h w -> h w c')
+ # print('Label: {}'.format(video_data[1]))
+ # print(img0.shape)
+ # img0 = Image.fromarray(np.uint8(img0 * 255))
+ # img0.save('./img{}.jpg'.format(i))
+ exit()
\ No newline at end of file
diff --git a/datasets/ucf101_image_datasets.py b/datasets/ucf101_image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5a44e3ca48ba453bcf1b822eb2bc58da008b5a1
--- /dev/null
+++ b/datasets/ucf101_image_datasets.py
@@ -0,0 +1,279 @@
+import os, io
+import re
+import json
+import torch
+import decord
+import torchvision
+import numpy as np
+
+
+from PIL import Image
+from einops import rearrange
+from typing import Dict, List, Tuple
+from torchvision import transforms
+import random
+
+
+class_labels_map = None
+cls_sample_cnt = None
+
+class_labels_map = None
+cls_sample_cnt = None
+
+
+def temporal_sampling(frames, start_idx, end_idx, num_samples):
+ """
+ Given the start and end frame index, sample num_samples frames between
+ the start and end with equal interval.
+ Args:
+ frames (tensor): a tensor of video frames, dimension is
+ `num video frames` x `channel` x `height` x `width`.
+ start_idx (int): the index of the start frame.
+ end_idx (int): the index of the end frame.
+ num_samples (int): number of frames to sample.
+ Returns:
+ frames (tersor): a tensor of temporal sampled video frames, dimension is
+ `num clip frames` x `channel` x `height` x `width`.
+ """
+ index = torch.linspace(start_idx, end_idx, num_samples)
+ index = torch.clamp(index, 0, frames.shape[0] - 1).long()
+ frames = torch.index_select(frames, 0, index)
+ return frames
+
+
+def get_filelist(file_path):
+ Filelist = []
+ for home, dirs, files in os.walk(file_path):
+ for filename in files:
+ Filelist.append(os.path.join(home, filename))
+ # Filelist.append( filename)
+ return Filelist
+
+
+def load_annotation_data(data_file_path):
+ with open(data_file_path, 'r') as data_file:
+ return json.load(data_file)
+
+
+def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
+ global class_labels_map, cls_sample_cnt
+
+ if class_labels_map is not None:
+ return class_labels_map, cls_sample_cnt
+ else:
+ cls_sample_cnt = {}
+ class_labels_map = load_annotation_data(anno_pth)
+ for cls in class_labels_map:
+ cls_sample_cnt[cls] = 0
+ return class_labels_map, cls_sample_cnt
+
+
+def load_annotations(ann_file, num_class, num_samples_per_cls):
+ dataset = []
+ class_to_idx, cls_sample_cnt = get_class_labels(num_class)
+ with open(ann_file, 'r') as fin:
+ for line in fin:
+ line_split = line.strip().split('\t')
+ sample = {}
+ idx = 0
+ # idx for frame_dir
+ frame_dir = line_split[idx]
+ sample['video'] = frame_dir
+ idx += 1
+
+ # idx for label[s]
+ label = [x for x in line_split[idx:]]
+ assert label, f'missing label in line: {line}'
+ assert len(label) == 1
+ class_name = label[0]
+ class_index = int(class_to_idx[class_name])
+
+ # choose a class subset of whole dataset
+ if class_index < num_class:
+ sample['label'] = class_index
+ if cls_sample_cnt[class_name] < num_samples_per_cls:
+ dataset.append(sample)
+ cls_sample_cnt[class_name]+=1
+
+ return dataset
+
+
+def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
+ """Finds the class folders in a dataset.
+
+ See :class:`DatasetFolder` for details.
+ """
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+ if not classes:
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+ return classes, class_to_idx
+
+
+class DecordInit(object):
+ """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
+
+ def __init__(self, num_threads=1):
+ self.num_threads = num_threads
+ self.ctx = decord.cpu(0)
+
+ def __call__(self, filename):
+ """Perform the Decord initialization.
+ Args:
+ results (dict): The resulting dict to be modified and passed
+ to the next transform in pipeline.
+ """
+ reader = decord.VideoReader(filename,
+ ctx=self.ctx,
+ num_threads=self.num_threads)
+ return reader
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'sr={self.sr},'
+ f'num_threads={self.num_threads})')
+ return repr_str
+
+
+class UCF101Images(torch.utils.data.Dataset):
+ """Load the UCF101 video files
+
+ Args:
+ target_video_len (int): the number of video frames will be load.
+ align_transform (callable): Align different videos in a specified size.
+ temporal_sample (callable): Sample the target length of a video.
+ """
+
+ def __init__(self,
+ configs,
+ transform=None,
+ temporal_sample=None):
+ self.configs = configs
+ self.data_path = configs.data_path
+ self.video_lists = get_filelist(configs.data_path)
+ self.transform = transform
+ self.temporal_sample = temporal_sample
+ self.target_video_len = self.configs.num_frames
+ self.v_decoder = DecordInit()
+ self.classes, self.class_to_idx = find_classes(self.data_path)
+ self.video_num = len(self.video_lists)
+
+ # ucf101 video frames
+ self.frame_data_path = configs.frame_data_path # important
+
+ self.video_frame_txt = configs.frame_data_txt
+ self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
+ random.shuffle(self.video_frame_files)
+ self.use_image_num = configs.use_image_num
+ self.image_tranform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ self.video_frame_num = len(self.video_frame_files)
+
+
+ def __getitem__(self, index):
+
+ video_index = index % self.video_num
+ path = self.video_lists[video_index]
+ class_name = path.split('/')[-2]
+ class_index = self.class_to_idx[class_name]
+
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
+ total_frames = len(vframes)
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
+ video = vframes[frame_indice]
+
+ # videotransformer data proprecess
+ video = self.transform(video) # T C H W
+ images = []
+ image_names = []
+ for i in range(self.use_image_num):
+ while True:
+ try:
+ video_frame_path = self.video_frame_files[index+i]
+ image_class_name = video_frame_path.split('_')[1]
+ image_class_index = self.class_to_idx[image_class_name]
+ video_frame_path = os.path.join(self.frame_data_path, video_frame_path)
+ image = Image.open(video_frame_path).convert('RGB')
+ image = self.image_tranform(image).unsqueeze(0)
+ images.append(image)
+ image_names.append(str(image_class_index))
+ break
+ except Exception as e:
+ index = random.randint(0, self.video_frame_num - self.use_image_num)
+ images = torch.cat(images, dim=0)
+ assert len(images) == self.use_image_num
+ assert len(image_names) == self.use_image_num
+
+ image_names = '====='.join(image_names)
+
+ video_cat = torch.cat([video, images], dim=0)
+
+ return {'video': video_cat,
+ 'video_name': class_index,
+ 'image_name': image_names}
+
+ def __len__(self):
+ return self.video_frame_num
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import video_transforms
+ import torch.utils.data as Data
+ import torchvision.transforms as transforms
+
+ from PIL import Image
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_frames", type=int, default=16)
+ parser.add_argument("--frame_interval", type=int, default=3)
+ parser.add_argument("--use-image-num", type=int, default=5)
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/UCF101/videos/")
+ parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
+ parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/UCF101/train_256_list.txt")
+ config = parser.parse_args()
+
+
+ temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
+
+ transform_ucf101 = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.RandomHorizontalFlipVideo(),
+ video_transforms.UCFCenterCropVideo(256),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+
+ ffs_dataset = UCF101Images(config, transform=transform_ucf101, temporal_sample=temporal_sample)
+ ffs_dataloader = Data.DataLoader(dataset=ffs_dataset, batch_size=6, shuffle=False, num_workers=1)
+
+ # for i, video_data in enumerate(ffs_dataloader):
+ for video_data in ffs_dataloader:
+ # print(type(video_data))
+ video = video_data['video']
+ # video_name = video_data['video_name']
+ print(video.shape)
+ print(video_data['image_name'])
+ image_name = video_data['image_name']
+ image_names = []
+ for caption in image_name:
+ single_caption = [int(item) for item in caption.split('=====')]
+ image_names.append(torch.as_tensor(single_caption))
+ print(image_names)
+ # print(video_name)
+ # print(video_data[2])
+
+ # for i in range(16):
+ # img0 = rearrange(video_data[0][0][i], 'c h w -> h w c')
+ # print('Label: {}'.format(video_data[1]))
+ # print(img0.shape)
+ # img0 = Image.fromarray(np.uint8(img0 * 255))
+ # img0.save('./img{}.jpg'.format(i))
diff --git a/datasets/video_transforms.py b/datasets/video_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..88260e00d4353f3a955b0c608dd8f86bb0e3a0bf
--- /dev/null
+++ b/datasets/video_transforms.py
@@ -0,0 +1,482 @@
+import torch
+import random
+import numbers
+from torchvision.transforms import RandomCrop, RandomResizedCrop
+
+def _is_tensor_video_clip(clip):
+ if not torch.is_tensor(clip):
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+ if not clip.ndimension() == 4:
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+ return True
+
+
+def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+
+def crop(clip, i, j, h, w):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ """
+ if len(clip.size()) != 4:
+ raise ValueError("clip should be a 4D tensor")
+ return clip[..., i : i + h, j : j + w]
+
+
+def resize(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
+
+def resize_scale(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ H, W = clip.size(-2), clip.size(-1)
+ scale_ = target_size[0] / min(H, W)
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
+
+
+def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
+ """
+ Do spatial cropping and resizing to the video clip
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
+ h (int): Height of the cropped region.
+ w (int): Width of the cropped region.
+ size (tuple(int, int)): height and width of resized clip
+ Returns:
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ clip = crop(clip, i, j, h, w)
+ clip = resize(clip, size, interpolation_mode)
+ return clip
+
+
+def center_crop(clip, crop_size):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ th, tw = crop_size
+ if h < th or w < tw:
+ raise ValueError("height and width must be no smaller than crop_size")
+
+ i = int(round((h - th) / 2.0))
+ j = int(round((w - tw) / 2.0))
+ return crop(clip, i, j, th, tw)
+
+
+def center_crop_using_short_edge(clip):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ if h < w:
+ th, tw = h, h
+ i = 0
+ j = int(round((w - tw) / 2.0))
+ else:
+ th, tw = w, w
+ i = int(round((h - th) / 2.0))
+ j = 0
+ return crop(clip, i, j, th, tw)
+
+
+def random_shift_crop(clip):
+ '''
+ Slide along the long edge, with the short edge as crop size
+ '''
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+
+ if h <= w:
+ long_edge = w
+ short_edge = h
+ else:
+ long_edge = h
+ short_edge =w
+
+ th, tw = short_edge, short_edge
+
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
+ return crop(clip, i, j, th, tw)
+
+
+def to_tensor(clip):
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ _is_tensor_video_clip(clip)
+ if not clip.dtype == torch.uint8:
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
+ return clip.float() / 255.0
+
+
+def normalize(clip, mean, std, inplace=False):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+ mean (tuple): pixel RGB mean. Size is (3)
+ std (tuple): pixel standard deviation. Size is (3)
+ Returns:
+ normalized clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ if not inplace:
+ clip = clip.clone()
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+ # print(mean)
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+ return clip
+
+
+def hflip(clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+ Returns:
+ flipped clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ return clip.flip(-1)
+
+
+class RandomCropVideo:
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: randomly cropped video clip.
+ size is (T, C, OH, OW)
+ """
+ i, j, h, w = self.get_params(clip)
+ return crop(clip, i, j, h, w)
+
+ def get_params(self, clip):
+ h, w = clip.shape[-2:]
+ th, tw = self.size
+
+ if h < th or w < tw:
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
+
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
+
+ return i, j, th, tw
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size})"
+
+class CenterCropResizeVideo:
+ '''
+ First use the short side for cropping length,
+ center crop video, then resize to the specified size
+ '''
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_center_crop = center_crop_using_short_edge(clip)
+ clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ return clip_center_crop_resize
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+class UCFCenterCropVideo:
+ '''
+ First scale to the specified size in equal proportion to the short edge,
+ then center cropping
+ '''
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ clip_center_crop = center_crop(clip_resize, self.size)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+class KineticsRandomCropResizeVideo:
+ '''
+ Slide along the long edge, with the short edge as crop size. And resie to the desired size.
+ '''
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+ def __call__(self, clip):
+ clip_random_crop = random_shift_crop(clip)
+ clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
+ return clip_resize
+
+
+class CenterCropVideo:
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_center_crop = center_crop(clip, self.size)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+
+class NormalizeVideo:
+ """
+ Normalize the video clip by mean subtraction and division by standard deviation
+ Args:
+ mean (3-tuple): pixel RGB mean
+ std (3-tuple): pixel RGB standard deviation
+ inplace (boolean): whether do in-place normalization
+ """
+
+ def __init__(self, mean, std, inplace=False):
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
+ """
+ return normalize(clip, self.mean, self.std, self.inplace)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
+
+
+class ToTensorVideo:
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ return to_tensor(clip)
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__
+
+
+class RandomHorizontalFlipVideo:
+ """
+ Flip the video clip along the horizontal direction with a given probability
+ Args:
+ p (float): probability of the clip being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if random.random() < self.p:
+ clip = hflip(clip)
+ return clip
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(p={self.p})"
+
+# ------------------------------------------------------------
+# --------------------- Sampling ---------------------------
+# ------------------------------------------------------------
+class TemporalRandomCrop(object):
+ """Temporally crop the given frame indices at a random location.
+
+ Args:
+ size (int): Desired length of frames will be seen in the model.
+ """
+
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, total_frames):
+ rand_end = max(0, total_frames - self.size - 1)
+ begin_index = random.randint(0, rand_end)
+ end_index = min(begin_index + self.size, total_frames)
+ return begin_index, end_index
+
+
+if __name__ == '__main__':
+ from torchvision import transforms
+ import torchvision.io as io
+ import numpy as np
+ from torchvision.utils import save_image
+ import os
+
+ vframes, aframes, info = io.read_video(
+ filename='./v_Archery_g01_c03.avi',
+ pts_unit='sec',
+ output_format='TCHW'
+ )
+
+ trans = transforms.Compose([
+ ToTensorVideo(),
+ RandomHorizontalFlipVideo(),
+ UCFCenterCropVideo(512),
+ # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ target_video_len = 32
+ frame_interval = 1
+ total_frames = len(vframes)
+ print(total_frames)
+
+ temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
+
+
+ # Sampling video frames
+ start_frame_ind, end_frame_ind = temporal_sample(total_frames)
+ # print(start_frame_ind)
+ # print(end_frame_ind)
+ assert end_frame_ind - start_frame_ind >= target_video_len
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
+ print(frame_indice)
+
+ select_vframes = vframes[frame_indice]
+ print(select_vframes.shape)
+ print(select_vframes.dtype)
+
+ select_vframes_trans = trans(select_vframes)
+ print(select_vframes_trans.shape)
+ print(select_vframes_trans.dtype)
+
+ select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
+ print(select_vframes_trans_int.dtype)
+ print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
+
+ io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
+
+ for i in range(target_video_len):
+ save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
\ No newline at end of file
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dd7b8f0cedaba3f674294a9623ccd93dcf7847d
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,284 @@
+import gradio as gr
+import os
+import torch
+import argparse
+import torchvision
+
+
+from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
+ EulerDiscreteScheduler, DPMSolverMultistepScheduler,
+ HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
+ DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
+from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
+from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
+from omegaconf import OmegaConf
+from transformers import T5EncoderModel, T5Tokenizer
+
+import os, sys
+sys.path.append(os.path.split(sys.path[0])[0])
+from sample.pipeline_latte import LattePipeline
+from models import get_models
+# import imageio
+from torchvision.utils import save_image
+import spaces
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
+args = parser.parse_args()
+args = OmegaConf.load(args.config)
+
+torch.set_grad_enabled(False)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+transformer_model = get_models(args).to(device, dtype=torch.float16)
+# state_dict = find_model(args.ckpt)
+# msg, unexp = transformer_model.load_state_dict(state_dict, strict=False)
+
+if args.enable_vae_temporal_decoder:
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+else:
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)
+tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
+
+# set eval mode
+transformer_model.eval()
+vae.eval()
+text_encoder.eval()
+
+@spaces.GPU
+def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
+ torch.manual_seed(seed)
+ if sample_method == 'DDIM':
+ scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type,
+ clip_sample=False)
+ elif sample_method == 'EulerDiscrete':
+ scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif sample_method == 'DDPM':
+ scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type,
+ clip_sample=False)
+ elif sample_method == 'DPMSolverMultistep':
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif sample_method == 'DPMSolverSinglestep':
+ scheduler = DPMSolverSinglestepScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif sample_method == 'PNDM':
+ scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif sample_method == 'HeunDiscrete':
+ scheduler = HeunDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif sample_method == 'EulerAncestralDiscrete':
+ scheduler = EulerAncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif sample_method == 'DEISMultistep':
+ scheduler = DEISMultistepScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif sample_method == 'KDPM2AncestralDiscrete':
+ scheduler = KDPM2AncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+
+
+ videogen_pipeline = LattePipeline(vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer_model).to(device)
+ # videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+ videos = videogen_pipeline(text_input,
+ video_length=video_length,
+ height=height,
+ width=width,
+ num_inference_steps=diffusion_step,
+ guidance_scale=scfg_scale,
+ enable_temporal_attentions=args.enable_temporal_attentions,
+ num_images_per_prompt=1,
+ mask_feature=True,
+ enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
+ ).video
+
+ save_path = args.save_img_path + 'temp' + '.mp4'
+ torchvision.io.write_video(save_path, videos[0], fps=8)
+ return save_path
+
+
+if not os.path.exists(args.save_img_path):
+ os.makedirs(args.save_img_path)
+
+intro = """
+
+
Latte: Latent Diffusion Transformer for Video Generation
+
+"""
+
+with gr.Blocks() as demo:
+ # gr.HTML(intro)
+ # with gr.Accordion("README", open=False):
+ # gr.HTML(
+ # """
+ #
+ # project page | paper
+ #
+
+ # We will continue update Latte.
+ # """
+ # )
+ gr.Markdown("Latte: Latent Diffusion Transformer for Video Generation")
+ gr.Markdown(
+ """
+
Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!
+ """
+ )
+ gr.Markdown(
+ """
+ """
+ )
+
+
+ with gr.Row():
+ with gr.Column(visible=True) as input_raws:
+ with gr.Row():
+ with gr.Column(scale=1.0):
+ # text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
+ text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")
+ # with gr.Row():
+ # with gr.Column(scale=0.5):
+ # image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
+ # with gr.Column(scale=0.5):
+ # preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")
+ # with gr.Row():
+ # with gr.Column(scale=1.0):
+ # video_length = gr.Slider(
+ # minimum=1,
+ # maximum=24,
+ # value=1,
+ # step=1,
+ # interactive=True,
+ # label="Video Length (1 for T2I and 16 for T2V)",
+ # )
+ with gr.Column(scale=0.5):
+ video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)
+ with gr.Row():
+ with gr.Column(scale=1.0):
+ scfg_scale = gr.Slider(
+ minimum=1,
+ maximum=50,
+ value=7.5,
+ step=0.1,
+ interactive=True,
+ label="Guidence Scale",
+ )
+ with gr.Row():
+ with gr.Column(scale=1.0):
+ seed = gr.Slider(
+ minimum=1,
+ maximum=2147483647,
+ value=100,
+ step=1,
+ interactive=True,
+ label="Seed",
+ )
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ height = gr.Slider(
+ minimum=256,
+ maximum=768,
+ value=512,
+ step=16,
+ interactive=False,
+ label="Height",
+ )
+ # with gr.Row():
+ with gr.Column(scale=0.5):
+ width = gr.Slider(
+ minimum=256,
+ maximum=768,
+ value=512,
+ step=16,
+ interactive=False,
+ label="Width",
+ )
+ with gr.Row():
+ with gr.Column(scale=1.0):
+ diffusion_step = gr.Slider(
+ minimum=20,
+ maximum=250,
+ value=50,
+ step=1,
+ interactive=True,
+ label="Sampling Step",
+ )
+
+
+ with gr.Column(scale=0.6, visible=True) as video_upload:
+ # with gr.Column(visible=True) as video_upload:
+ output = gr.Video(interactive=False, include_audio=True, elem_id="่พๅบ็่ง้ข") #.style(height=360)
+ # with gr.Column(elem_id="image", scale=0.5) as img_part:
+ # with gr.Tab("Video", elem_id='video_tab'):
+
+ # with gr.Tab("Image", elem_id='image_tab'):
+ # up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload").style(height=360)
+ # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
+ # clear = gr.Button("Restart")
+
+ with gr.Row():
+ with gr.Column(scale=1.0, min_width=0):
+ run = gr.Button("๐ญRun")
+ # with gr.Column(scale=0.5, min_width=0):
+ # clear = gr.Button("๐Clear๏ธ")
+
+ run.click(gen_video, [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], [output])
+
+demo.launch(debug=False, share=True)
+
+# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
diff --git a/diffusion/__init__.py b/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..30bd2db911cdb43cb1e8385baccc1e8ee49f1184
--- /dev/null
+++ b/diffusion/__init__.py
@@ -0,0 +1,47 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+def create_diffusion(
+ timestep_respacing,
+ noise_schedule="linear",
+ use_kl=False,
+ sigma_small=False,
+ predict_xstart=False,
+ learn_sigma=True,
+ # learn_sigma=False,
+ rescale_learned_sigmas=False,
+ diffusion_steps=1000
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if timestep_respacing is None or timestep_respacing == "":
+ timestep_respacing = [diffusion_steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type
+ # rescale_timesteps=rescale_timesteps,
+ )
diff --git a/diffusion/diffusion_utils.py b/diffusion/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1756908f0e22613c7e8ed91936c232b097650c13
--- /dev/null
+++ b/diffusion/diffusion_utils.py
@@ -0,0 +1,88 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import torch as th
+import numpy as np
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a continuous Gaussian distribution.
+ :param x: the targets
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ normalized_x = centered_x * inv_stdv
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+ return log_probs
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9933d65e9e7db6780db524c09b02ac0304753584
--- /dev/null
+++ b/diffusion/gaussian_diffusion.py
@@ -0,0 +1,881 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import math
+
+import numpy as np
+import torch as th
+import enum
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type
+ ):
+
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ ) if len(self.posterior_variance) > 1 else np.array([])
+
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, F, C = x.shape[:3]
+ assert t.shape == (B,)
+ model_output = model(x, t, **model_kwargs)
+ # try:
+ # model_output = model_output.sample # for tav unet
+ # except:
+ # model_output = model(x, t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(x_t, t, **model_kwargs)
+ # try:
+ # model_output = model(x_t, t, **model_kwargs).sample # for tav unet
+ # except:
+ # model_output = model(x_t, t, **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, F, C = x_t.shape[:3]
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diff --git a/diffusion/respace.py b/diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..a23f8e2a90e91efc5de3d368a7f3503c9e1bca7f
--- /dev/null
+++ b/diffusion/respace.py
@@ -0,0 +1,130 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+import torch
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ # @torch.compile
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ # if self.rescale_timesteps:
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
diff --git a/diffusion/timestep_sampler.py b/diffusion/timestep_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cab6f4ba5956ac839642f1c049f6d3b9fcdb5c4
--- /dev/null
+++ b/diffusion/timestep_sampler.py
@@ -0,0 +1,150 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
diff --git a/docs/datasets_evaluation.md b/docs/datasets_evaluation.md
new file mode 100644
index 0000000000000000000000000000000000000000..b04e2f9f13cc8231580ed2d4013b0f9ccd016ab5
--- /dev/null
+++ b/docs/datasets_evaluation.md
@@ -0,0 +1,53 @@
+## Download datasets
+
+Here are the links to download the datasets [FaceForensics](https://huggingface.co./datasets/maxin-cn/FaceForensics), [SkyTimelapse](https://huggingface.co./datasets/maxin-cn/SkyTimelapse/tree/main), [UCF101](https://www.crcv.ucf.edu/data/UCF101/UCF101.rar), and [Taichi-HD](https://huggingface.co./datasets/maxin-cn/Taichi-HD).
+
+
+## Dataset structure
+
+All datasets follow their original dataset structure. As for video-image joint training, there is a `train_list.txt` file, whose format is `video_name/frame.jpg`. Here, we show an example of the FaceForensics datsset.
+
+All datasets retain their original structure. For video-image joint training, there is a `train_list.txt` file formatted as `video_name/frame.jpg`. Below is an example from the FaceForensics dataset.
+
+```bash
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000306.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000111.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000007.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000057.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000084.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000268.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000270.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000259.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000127.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000099.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000189.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000228.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000026.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000081.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000094.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000223.jpg
+aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000055.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000486.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000396.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000475.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000028.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000261.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000294.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000257.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000490.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000143.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000190.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000476.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000397.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000437.jpg
+qEnKi82wWgE_2_rJPM8EdWShs_1/000071.jpg
+```
+
+## Evaluation
+
+We follow [StyleGAN-V](https://github.com/universome/stylegan-v) to measure the quality of the generated video. The code for calculating the relevant metrics is located in [tools](../tools/) folder. To measure the quantitative metrics of your generated results, you need to put all the videos from real data into a folder and turn them into video frames (the same goes for fake data). Then you can run the following command on one GPU:
+
+```bash
+# cd Latte
+bash tools/eval_metrics.sh
+```
\ No newline at end of file
diff --git a/docs/latte_diffusers.md b/docs/latte_diffusers.md
new file mode 100644
index 0000000000000000000000000000000000000000..3965510e71bf237b42a6390784fa70c884767f78
--- /dev/null
+++ b/docs/latte_diffusers.md
@@ -0,0 +1,106 @@
+## Requirements
+
+Please follow [README](../README.md) to install the environment. After installation, update the version of `diffusers` at leaset to 0.30.0.
+
+## Inference
+
+```bash
+from diffusers import LattePipeline
+from diffusers.models import AutoencoderKLTemporalDecoder
+
+from torchvision.utils import save_image
+
+import torch
+import imageio
+
+torch.manual_seed(0)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+video_length = 1 # 1 or 16
+pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to(device)
+
+# if you want to use the temporal decoder of VAE, please uncomment the following codes
+# vae = AutoencoderKLTemporalDecoder.from_pretrained("maxin-cn/Latte-1", subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+# pipe.vae = vae
+
+prompt = "a cat wearing sunglasses and working as a lifeguard at pool."
+videos = pipe(prompt, video_length=video_length, output_type='pt').frames.cpu()
+
+if video_length > 1:
+ videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8
+ imageio.mimwrite('./latte_output.mp4', videos[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest is 0
+else:
+ save_image(videos[0], './latte_output.png')
+```
+
+## Inference with 4/8-bit quantization
+[@Aryan](https://github.com/a-r-r-o-w) provides a quantization solution for inference, which can reduce GPU memory from 17 GB to 9 GB. Note that please install `bitsandbytes` (`pip install bitsandbytes`).
+
+```bash
+import gc
+
+import torch
+from diffusers import LattePipeline
+from transformers import T5EncoderModel, BitsAndBytesConfig
+import imageio
+from torchvision.utils import save_image
+
+torch.manual_seed(0)
+
+def flush():
+ gc.collect()
+ torch.cuda.empty_cache()
+
+def bytes_to_giga_bytes(bytes):
+ return bytes / 1024 / 1024 / 1024
+
+video_length = 16
+model_id = "maxin-cn/Latte-1/"
+
+text_encoder = T5EncoderModel.from_pretrained(
+ model_id,
+ subfolder="text_encoder",
+ quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
+ device_map="auto",
+)
+pipe = LattePipeline.from_pretrained(
+ model_id,
+ text_encoder=text_encoder,
+ transformer=None,
+ device_map="balanced",
+)
+
+with torch.no_grad():
+ prompt = "a cat wearing sunglasses and working as a lifeguard at pool."
+ negative_prompt = ""
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
+
+del text_encoder
+del pipe
+flush()
+
+pipe = LattePipeline.from_pretrained(
+ model_id,
+ text_encoder=None,
+ torch_dtype=torch.float16,
+).to("cuda")
+# pipe.enable_vae_tiling()
+# pipe.enable_vae_slicing()
+
+videos = pipe(
+ video_length=video_length,
+ num_inference_steps=50,
+ negative_prompt=None,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ output_type="pt",
+).frames.cpu()
+
+print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
+
+if video_length > 1:
+ videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8
+ imageio.mimwrite('./latte_output.mp4', videos[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest is 0
+else:
+ save_image(videos[0], './latte_output.png')
+```
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f4a4ad28501ba1ad6e028d6c59197e3570b1f917
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,25 @@
+name: latte
+channels:
+ - pytorch
+ - nvidia
+dependencies:
+ - python >= 3.10
+ - pytorch > 2.0.0
+ - torchvision
+ - pytorch-cuda >= 11.7
+ - pip:
+ - timm
+ - diffusers[torch]==0.24.0
+ - accelerate
+ - tensorboard
+ - einops
+ - transformers
+ - av
+ - scikit-image
+ - decord
+ - pandas
+ - imageio-ffmpeg
+ - sentencepiece
+ - beautifulsoup4
+ - ftfy
+ - omegaconf
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76cc3eeff082af06c124a7f1926a9dec30172aae
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,52 @@
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from .latte import Latte_models
+from .latte_img import LatteIMG_models
+from .latte_t2v import LatteT2V
+
+from torch.optim.lr_scheduler import LambdaLR
+
+
+def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
+ from torch.optim.lr_scheduler import LambdaLR
+ def fn(step):
+ if warmup_steps > 0:
+ return min(step / warmup_steps, 1)
+ else:
+ return 1
+ return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+ if name == 'warmup':
+ return customized_lr_scheduler(optimizer, **kwargs)
+ elif name == 'cosine':
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ return CosineAnnealingLR(optimizer, **kwargs)
+ else:
+ raise NotImplementedError(name)
+
+def get_models(args):
+ if 'LatteIMG' in args.model:
+ return LatteIMG_models[args.model](
+ input_size=args.latent_size,
+ num_classes=args.num_classes,
+ num_frames=args.num_frames,
+ learn_sigma=args.learn_sigma,
+ extras=args.extras
+ )
+ elif 'LatteT2V' in args.model:
+ return LatteT2V.from_pretrained(args.pretrained_model_path, subfolder="transformer", video_length=args.video_length)
+ elif 'Latte' in args.model:
+ return Latte_models[args.model](
+ input_size=args.latent_size,
+ num_classes=args.num_classes,
+ num_frames=args.num_frames,
+ learn_sigma=args.learn_sigma,
+ extras=args.extras
+ )
+ else:
+ raise '{} Model Not Supported!'.format(args.model)
+
\ No newline at end of file
diff --git a/models/__pycache__/__init__.cpython-312.pyc b/models/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dde80a0c7c529cf8d91f410e8857b5c8d0f0d2f5
Binary files /dev/null and b/models/__pycache__/__init__.cpython-312.pyc differ
diff --git a/models/__pycache__/latte.cpython-312.pyc b/models/__pycache__/latte.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca64f8d67cb4c73bae22cae132ebfbafc75a4aa8
Binary files /dev/null and b/models/__pycache__/latte.cpython-312.pyc differ
diff --git a/models/__pycache__/latte_img.cpython-312.pyc b/models/__pycache__/latte_img.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc4eb408a2e9085b0e0084011df87dfa157eda83
Binary files /dev/null and b/models/__pycache__/latte_img.cpython-312.pyc differ
diff --git a/models/__pycache__/latte_t2v.cpython-312.pyc b/models/__pycache__/latte_t2v.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2c8d2bed30b56638c029af6f83fa99780697e0e
Binary files /dev/null and b/models/__pycache__/latte_t2v.cpython-312.pyc differ
diff --git a/models/clip.py b/models/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..a818273c08e81b1be428f406492ce48decf40229
--- /dev/null
+++ b/models/clip.py
@@ -0,0 +1,126 @@
+import numpy
+import torch.nn as nn
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
+
+import transformers
+transformers.logging.set_verbosity_error()
+
+"""
+Will encounter following warning:
+- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
+or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
+- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
+that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
+
+https://github.com/CompVis/stable-diffusion/issues/97
+according to this issue, this warning is safe.
+
+This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
+You can safely ignore the warning, it is not an error.
+
+This clip usage is from U-ViT and same with Stable Diffusion.
+"""
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
+ def __init__(self, path, device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
+ self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ pooled_z = outputs.pooler_output
+ return z, pooled_z
+
+ def encode(self, text):
+ return self(text)
+
+
+class TextEmbedder(nn.Module):
+ """
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
+ """
+ def __init__(self, path, dropout_prob=0.1):
+ super().__init__()
+ self.text_encodder = FrozenCLIPEmbedder(path=path)
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ # TODO
+ drop_ids = force_drop_ids == 1
+ labels = list(numpy.where(drop_ids, "", text_prompts))
+ # print(labels)
+ return labels
+
+ def forward(self, text_prompts, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
+ embeddings, pooled_embeddings = self.text_encodder(text_prompts)
+ # return embeddings, pooled_embeddings
+ return pooled_embeddings
+
+
+if __name__ == '__main__':
+
+ r"""
+ Returns:
+
+ Examples from CLIPTextModel:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
+ dropout_prob=0.00001).to(device)
+
+ text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
+ # text_prompt = ('None', 'None', 'None')
+ output, pooled_output = text_encoder(text_prompts=text_prompt, train=False)
+ # print(output)
+ print(output.shape)
+ print(pooled_output.shape)
+ # print(output.shape)
\ No newline at end of file
diff --git a/models/latte.py b/models/latte.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7ee920e3f6bbf137355e0bd1bfc62f93ba5a21b
--- /dev/null
+++ b/models/latte.py
@@ -0,0 +1,526 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+from einops import rearrange, repeat
+from timm.models.vision_transformer import Mlp, PatchEmbed
+
+# the xformers lib allows less memory, faster training and inference
+try:
+ import xformers
+ import xformers.ops
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+# from timm.models.layers.helpers import to_2tuple
+# from timm.models.layers.trace_utils import _assert
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+#################################################################################
+# Attention Layers from TIMM #
+#################################################################################
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+ self.attention_mode = attention_mode
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.attention_mode == 'xformers': # cause loss nan while using with amp
+ # https://github.com/facebookresearch/xformers/blob/e8bd8f932c2f48e3a3171d06749eecbbf1de420c/xformers/ops/fmha/__init__.py#L135
+ q_xf = q.transpose(1,2).contiguous()
+ k_xf = k.transpose(1,2).contiguous()
+ v_xf = v.transpose(1,2).contiguous()
+ x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
+
+ elif self.attention_mode == 'flash':
+ # cause loss nan while using with amp
+ # Optionally use the context manager to ensure one of the fused kerenels is run
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C) # require pytorch 2.0
+
+ elif self.attention_mode == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+
+ else:
+ raise NotImplemented
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, use_fp16=False):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ if use_fp16:
+ t_freq = t_freq.to(dtype=torch.float16)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+#################################################################################
+# Core Latte Model #
+#################################################################################
+
+class TransformerBlock(nn.Module):
+ """
+ A Latte tansformer block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of Latte.
+ """
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class Latte(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_frames=16,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ extras=1,
+ attention_mode='math',
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.extras = extras
+ self.num_frames = num_frames
+
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ if self.extras == 2:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+ if self.extras == 78: # timestep + text_embedding
+ self.text_embedding_projection = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(77 * 768, hidden_size, bias=True)
+ )
+
+ num_patches = self.x_embedder.num_patches
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+ self.temp_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)
+ self.hidden_size = hidden_size
+
+ self.blocks = nn.ModuleList([
+ TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth)
+ ])
+
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ temp_embed = get_1d_sincos_temp_embed(self.temp_embed.shape[-1], self.temp_embed.shape[-2])
+ self.temp_embed.data.copy_(torch.from_numpy(temp_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ if self.extras == 2:
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in Latte blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+ return imgs
+
+ # @torch.cuda.amp.autocast()
+ # @torch.compile
+ def forward(self,
+ x,
+ t,
+ y=None,
+ text_embedding=None,
+ use_fp16=False):
+ """
+ Forward pass of Latte.
+ x: (N, F, C, H, W) tensor of video inputs
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+ if use_fp16:
+ x = x.to(dtype=torch.float16)
+
+ batches, frames, channels, high, weight = x.shape
+ x = rearrange(x, 'b f c h w -> (b f) c h w')
+ x = self.x_embedder(x) + self.pos_embed
+ t = self.t_embedder(t, use_fp16=use_fp16)
+ timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1])
+ timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1])
+
+ if self.extras == 2:
+ y = self.y_embedder(y, self.training)
+ y_spatial = repeat(y, 'n d -> (n c) d', c=self.temp_embed.shape[1])
+ y_temp = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1])
+ elif self.extras == 78:
+ text_embedding = self.text_embedding_projection(text_embedding.reshape(batches, -1))
+ text_embedding_spatial = repeat(text_embedding, 'n d -> (n c) d', c=self.temp_embed.shape[1])
+ text_embedding_temp = repeat(text_embedding, 'n d -> (n c) d', c=self.pos_embed.shape[1])
+
+ for i in range(0, len(self.blocks), 2):
+ spatial_block, temp_block = self.blocks[i:i+2]
+ if self.extras == 2:
+ c = timestep_spatial + y_spatial
+ elif self.extras == 78:
+ c = timestep_spatial + text_embedding_spatial
+ else:
+ c = timestep_spatial
+ x = spatial_block(x, c)
+
+ x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)
+ # Add Time Embedding
+ if i == 0:
+ x = x + self.temp_embed
+
+ if self.extras == 2:
+ c = timestep_temp + y_temp
+ elif self.extras == 78:
+ c = timestep_temp + text_embedding_temp
+ else:
+ c = timestep_temp
+
+ x = temp_block(x, c)
+ x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
+
+ if self.extras == 2:
+ c = timestep_spatial + y_spatial
+ else:
+ c = timestep_spatial
+ x = self.final_layer(x, c)
+ x = self.unpatchify(x)
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
+ return x
+
+ def forward_with_cfg(self, x, t, y=None, cfg_scale=7.0, use_fp16=False, text_embedding=None):
+ """
+ Forward pass of Latte, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ if use_fp16:
+ combined = combined.to(dtype=torch.float16)
+ model_out = self.forward(combined, t, y=y, use_fp16=use_fp16, text_embedding=text_embedding)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
+ eps, rest = model_out[:, :, :4, ...], model_out[:, :, 4:, ...]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=2)
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+def get_1d_sincos_temp_embed(embed_dim, length):
+ pos = torch.arange(0, length).unsqueeze(1)
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+
+ emb = np.concatenate([emb_h, emb_w], axis=1)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega
+
+ pos = pos.reshape(-1)
+ out = np.einsum('m,d->md', pos, omega)
+
+ emb_sin = np.sin(out)
+ emb_cos = np.cos(out)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1)
+ return emb
+
+
+#################################################################################
+# Latte Configs #
+#################################################################################
+
+def Latte_XL_2(**kwargs):
+ return Latte(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+def Latte_XL_4(**kwargs):
+ return Latte(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
+
+def Latte_XL_8(**kwargs):
+ return Latte(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
+
+def Latte_L_2(**kwargs):
+ return Latte(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
+
+def Latte_L_4(**kwargs):
+ return Latte(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
+
+def Latte_L_8(**kwargs):
+ return Latte(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
+
+def Latte_B_2(**kwargs):
+ return Latte(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+def Latte_B_4(**kwargs):
+ return Latte(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+def Latte_B_8(**kwargs):
+ return Latte(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+def Latte_S_2(**kwargs):
+ return Latte(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+def Latte_S_4(**kwargs):
+ return Latte(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+def Latte_S_8(**kwargs):
+ return Latte(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+Latte_models = {
+ 'Latte-XL/2': Latte_XL_2, 'Latte-XL/4': Latte_XL_4, 'Latte-XL/8': Latte_XL_8,
+ 'Latte-L/2': Latte_L_2, 'Latte-L/4': Latte_L_4, 'Latte-L/8': Latte_L_8,
+ 'Latte-B/2': Latte_B_2, 'Latte-B/4': Latte_B_4, 'Latte-B/8': Latte_B_8,
+ 'Latte-S/2': Latte_S_2, 'Latte-S/4': Latte_S_4, 'Latte-S/8': Latte_S_8,
+}
+
+if __name__ == '__main__':
+
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ img = torch.randn(3, 16, 4, 32, 32).to(device)
+ t = torch.tensor([1, 2, 3]).to(device)
+ y = torch.tensor([1, 2, 3]).to(device)
+ network = Latte_XL_2().to(device)
+ from thop import profile
+ flops, params = profile(network, inputs=(img, t))
+ print('FLOPs = ' + str(flops/1000**3) + 'G')
+ print('Params = ' + str(params/1000**2) + 'M')
+ # y_embeder = LabelEmbedder(num_classes=101, hidden_size=768, dropout_prob=0.5).to(device)
+ # lora.mark_only_lora_as_trainable(network)
+ # out = y_embeder(y, True)
+ # out = network(img, t, y)
+ # print(out.shape)
diff --git a/models/latte_img.py b/models/latte_img.py
new file mode 100644
index 0000000000000000000000000000000000000000..c468c6354fe8d34402ee412e0d8c16cc3e8ac37f
--- /dev/null
+++ b/models/latte_img.py
@@ -0,0 +1,552 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+from einops import rearrange, repeat
+from timm.models.vision_transformer import Mlp, PatchEmbed
+
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+# the xformers lib allows less memory, faster training and inference
+try:
+ import xformers
+ import xformers.ops
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+# from timm.models.layers.helpers import to_2tuple
+# from timm.models.layers.trace_utils import _assert
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+#################################################################################
+# Attention Layers from TIMM #
+#################################################################################
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+ self.attention_mode = attention_mode
+
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.attention_mode == 'xformers': # cause loss nan while using with amp
+ x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C)
+
+ elif self.attention_mode == 'flash':
+ # cause loss nan while using with amp
+ # Optionally use the context manager to ensure one of the fused kerenels is run
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C) # require pytorch 2.0
+
+ elif self.attention_mode == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+
+ else:
+ raise NotImplemented
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, use_fp16=False):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ if use_fp16:
+ t_freq = t_freq.to(dtype=torch.float16)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+#################################################################################
+# Core Latte Model #
+#################################################################################
+
+class TransformerBlock(nn.Module):
+ """
+ A Latte block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of Latte.
+ """
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class Latte(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_frames=16,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ extras=2,
+ attention_mode='math',
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.extras = extras
+ self.num_frames = num_frames
+
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ if self.extras == 2:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+ if self.extras == 78: # timestep + text_embedding
+ self.text_embedding_projection = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(1024, hidden_size, bias=True)
+ )
+
+ num_patches = self.x_embedder.num_patches
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+ self.temp_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)
+
+ self.blocks = nn.ModuleList([
+ TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth)
+ ])
+
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ temp_embed = get_1d_sincos_temp_embed(self.temp_embed.shape[-1], self.temp_embed.shape[-2])
+ self.temp_embed.data.copy_(torch.from_numpy(temp_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ if self.extras == 2:
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in Latte blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+ return imgs
+
+ # @torch.cuda.amp.autocast()
+ # @torch.compile
+ def forward(self, x, t, y=None, use_fp16=False, y_image=None, use_image_num=0):
+ """
+ Forward pass of Latte.
+ x: (N, F, C, H, W) tensor of video inputs
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ y_image: tensor of video frames
+ use_image_num: how many video frames are used
+ """
+ if use_fp16:
+ x = x.to(dtype=torch.float16)
+ batches, frames, channels, high, weight = x.shape
+ x = rearrange(x, 'b f c h w -> (b f) c h w')
+ x = self.x_embedder(x) + self.pos_embed
+ t = self.t_embedder(t, use_fp16=use_fp16)
+ timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num)
+ timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1])
+
+ if self.extras == 2:
+ y = self.y_embedder(y, self.training)
+ if self.training:
+ y_image_emb = []
+ # print(y_image)
+ for y_image_single in y_image:
+ # print(y_image_single)
+ y_image_single = y_image_single.reshape(1, -1)
+ y_image_emb.append(self.y_embedder(y_image_single, self.training))
+ y_image_emb = torch.cat(y_image_emb, dim=0)
+ y_spatial = repeat(y, 'n d -> n c d', c=self.temp_embed.shape[1])
+ y_spatial = torch.cat([y_spatial, y_image_emb], dim=1)
+ y_spatial = rearrange(y_spatial, 'n c d -> (n c) d')
+ else:
+ y_spatial = repeat(y, 'n d -> (n c) d', c=self.temp_embed.shape[1])
+
+ y_temp = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1])
+ elif self.extras == 78:
+ text_embedding = self.text_embedding_projection(text_embedding)
+ text_embedding_video = text_embedding[:, :1, :]
+ text_embedding_image = text_embedding[:, 1:, :]
+ text_embedding_video = repeat(text_embedding, 'n t d -> n (t c) d', c=self.temp_embed.shape[1])
+ text_embedding_spatial = torch.cat([text_embedding_video, text_embedding_image], dim=1)
+ text_embedding_spatial = rearrange(text_embedding_spatial, 'n t d -> (n t) d')
+ text_embedding_temp = repeat(text_embedding_video, 'n t d -> n (t c) d', c=self.pos_embed.shape[1])
+ text_embedding_temp = rearrange(text_embedding_temp, 'n t d -> (n t) d')
+
+ for i in range(0, len(self.blocks), 2):
+ spatial_block, temp_block = self.blocks[i:i+2]
+
+ if self.extras == 2:
+ c = timestep_spatial + y_spatial
+ elif self.extras == 78:
+ c = timestep_spatial + text_embedding_spatial
+ else:
+ c = timestep_spatial
+ x = spatial_block(x, c)
+
+ x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)
+ x_video = x[:, :(frames-use_image_num), :]
+ x_image = x[:, (frames-use_image_num):, :]
+
+ # Add Time Embedding
+ if i == 0:
+ x_video = x_video + self.temp_embed
+
+ if self.extras == 2:
+ c = timestep_temp + y_temp
+ elif self.extras == 78:
+ c = timestep_temp + text_embedding_temp
+ else:
+ c = timestep_temp
+
+ x_video = temp_block(x_video, c)
+ x = torch.cat([x_video, x_image], dim=1)
+ x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
+
+ if self.extras == 2:
+ c = timestep_spatial + y_spatial
+ else:
+ c = timestep_spatial
+ x = self.final_layer(x, c)
+ x = self.unpatchify(x)
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
+ # print(x.shape)
+ return x
+
+
+ def forward_with_cfg(self, x, t, y, cfg_scale, use_fp16=False):
+ """
+ Forward pass of Latte, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ if use_fp16:
+ combined = combined.to(dtype=torch.float16)
+ model_out = self.forward(combined, t, y, use_fp16=use_fp16)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
+ eps, rest = model_out[:, :, :4, ...], model_out[:, :, 4:, ...] # 2 16 4 32 32
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=2)
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+def get_1d_sincos_temp_embed(embed_dim, length):
+ pos = torch.arange(0, length).unsqueeze(1)
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+#################################################################################
+# Latte Configs #
+#################################################################################
+
+def Latte_XL_2(**kwargs):
+ return Latte(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+def Latte_XL_4(**kwargs):
+ return Latte(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
+
+def Latte_XL_8(**kwargs):
+ return Latte(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
+
+def Latte_L_2(**kwargs):
+ return Latte(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
+
+def Latte_L_4(**kwargs):
+ return Latte(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
+
+def Latte_L_8(**kwargs):
+ return Latte(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
+
+def Latte_B_2(**kwargs):
+ return Latte(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+def Latte_B_4(**kwargs):
+ return Latte(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+def Latte_B_8(**kwargs):
+ return Latte(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+def Latte_S_2(**kwargs):
+ return Latte(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+def Latte_S_4(**kwargs):
+ return Latte(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+def Latte_S_8(**kwargs):
+ return Latte(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+LatteIMG_models = {
+ 'LatteIMG-XL/2': Latte_XL_2, 'LatteIMG-XL/4': Latte_XL_4, 'LatteIMG-XL/8': Latte_XL_8,
+ 'LatteIMG-L/2': Latte_L_2, 'LatteIMG-L/4': Latte_L_4, 'LatteIMG-L/8': Latte_L_8,
+ 'LatteIMG-B/2': Latte_B_2, 'LatteIMG-B/4': Latte_B_4, 'LatteIMG-B/8': Latte_B_8,
+ 'LatteIMG-S/2': Latte_S_2, 'LatteIMG-S/4': Latte_S_4, 'LatteIMG-S/8': Latte_S_8,
+}
+
+if __name__ == '__main__':
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ use_image_num = 8
+
+ img = torch.randn(3, 16+use_image_num, 4, 32, 32).to(device)
+
+ t = torch.tensor([1, 2, 3]).to(device)
+ y = torch.tensor([1, 2, 3]).to(device)
+ y_image = [torch.tensor([48, 37, 72, 63, 74, 6, 7, 8]).to(device),
+ torch.tensor([37, 72, 63, 74, 70, 1, 2, 3]).to(device),
+ torch.tensor([72, 63, 74, 70, 71, 5, 8, 7]).to(device),
+ ]
+
+
+ network = Latte_XL_2().to(device)
+ network.train()
+
+ out = network(img, t, y=y, y_image=y_image, use_image_num=use_image_num)
+ print(out.shape)
\ No newline at end of file
diff --git a/models/latte_t2v.py b/models/latte_t2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1296fa1a1c45f3d7c86046f81873a547060e08c
--- /dev/null
+++ b/models/latte_t2v.py
@@ -0,0 +1,945 @@
+import torch
+
+import os
+import json
+
+from dataclasses import dataclass
+from einops import rearrange, repeat
+from typing import Any, Dict, Optional, Tuple
+from diffusers.models import Transformer2DModel
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings, CaptionProjection, PatchEmbed, CombinedTimestepSizeEmbeddings
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
+from diffusers.models.attention_processor import Attention
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(linear_cls(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+ for module in self.net:
+ if isinstance(module, compatible_cls):
+ hidden_states = module(hidden_states, scale)
+ else:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+@maybe_allow_in_graph
+class BasicTransformerBlock_(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) # go here
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # # 2. Cross-Attn
+ # if cross_attention_dim is not None or double_self_attention:
+ # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # # the second cross attention block.
+ # self.norm2 = (
+ # AdaLayerNorm(dim, num_embeds_ada_norm)
+ # if self.use_ada_layer_norm
+ # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ # )
+ # self.attn2 = Attention(
+ # query_dim=dim,
+ # cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ # heads=num_attention_heads,
+ # dim_head=attention_head_dim,
+ # dropout=dropout,
+ # bias=attention_bias,
+ # upcast_attention=upcast_attention,
+ # ) # is self-attn if encoder_hidden_states is none
+ # else:
+ # self.norm2 = None
+ # self.attn2 = None
+
+ # 3. Feed-forward
+ # if not self.use_ada_layer_norm_single:
+ # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single: # go here
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ # norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # # 3. Cross-Attention
+ # if self.attn2 is not None:
+ # if self.use_ada_layer_norm:
+ # norm_hidden_states = self.norm2(hidden_states, timestep)
+ # elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ # norm_hidden_states = self.norm2(hidden_states)
+ # elif self.use_ada_layer_norm_single:
+ # # For PixArt norm2 isn't applied here:
+ # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ # norm_hidden_states = hidden_states
+ # else:
+ # raise ValueError("Incorrect norm")
+
+ # if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+ # norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # attn_output = self.attn2(
+ # norm_hidden_states,
+ # encoder_hidden_states=encoder_hidden_states,
+ # attention_mask=encoder_attention_mask,
+ # **cross_attention_kwargs,
+ # )
+ # hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ # if not self.use_ada_layer_norm_single:
+ # norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ # norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = self.norm3(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = CombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ batch_size: int = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None)
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class LatteT2V(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ video_length: int = 16,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ self.video_length = video_length
+
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ self.is_input_patches = in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif self.is_input_vectorized and self.is_input_patches:
+ raise ValueError(
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+ " sure that either `num_vector_embeds` or `num_patches` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
+ raise ValueError(
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = linear_cls(in_channels, inner_dim)
+ else:
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+ elif self.is_input_patches:
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
+ interpolation_scale = max(interpolation_scale, 1)
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # Define temporal transformers blocks
+ self.temporal_transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock_( # one attention
+ inner_dim,
+ num_attention_heads, # num_attention_heads
+ attention_head_dim, # attention_head_dim 72
+ dropout=dropout,
+ cross_attention_dim=None,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=False,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+ # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = linear_cls(inner_dim, in_channels)
+ else:
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches and norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+ elif self.is_input_patches and norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ # 5. PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if norm_type == "ada_norm_single":
+ self.use_additional_conditions = self.config.sample_size == 128 # False, 128 -> 1024
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+ self.caption_projection = None
+ if caption_channels is not None:
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+ self.gradient_checkpointing = False
+
+ # define temporal positional embedding
+ temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: Optional[torch.LongTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_image_num: int = 0,
+ enable_temporal_attentions: bool = True,
+ return_dict: bool = True,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ input_batch_size, c, frame, h, w = hidden_states.shape
+ frame = frame - use_image_num
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w').contiguous()
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+ encoder_attention_mask = repeat(encoder_attention_mask, 'b 1 l -> (b f) 1 l', f=frame).contiguous()
+ elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
+ encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', f=frame).contiguous()
+ encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
+ encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
+ encoder_attention_mask = rearrange(encoder_attention_mask, 'b n l -> (b n) l').contiguous().unsqueeze(1)
+
+
+ # Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 1. Input
+ if self.is_input_patches: # here
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ num_patches = height * width
+
+ hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ # batch_size = hidden_states.shape[0]
+ batch_size = input_batch_size
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ # 2. Blocks
+ if self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
+
+ if use_image_num != 0 and self.training:
+ encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
+ encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', f=frame).contiguous()
+ encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
+ encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous()
+ else:
+ encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b t d -> (b f) t d', f=frame).contiguous()
+
+ # prepare timesteps for spatial and temporal block
+ timestep_spatial = repeat(timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous()
+ timestep_temp = repeat(timestep, 'b d -> (b p) d', p=num_patches).contiguous()
+
+ for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
+
+ if self.training and self.gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ spatial_block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ use_reentrant=False,
+ )
+
+ if enable_temporal_attentions:
+ hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous()
+
+ if use_image_num != 0: # image-video joitn training
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ if i == 0:
+ hidden_states_video = hidden_states_video + self.temp_pos_embed
+
+ hidden_states_video = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ use_reentrant=False,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
+
+ else:
+ if i == 0:
+ hidden_states = hidden_states + self.temp_pos_embed
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ use_reentrant=False,
+ )
+
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
+ else:
+ hidden_states = spatial_block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ )
+
+ if enable_temporal_attentions:
+
+ hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous()
+
+ if use_image_num != 0 and self.training:
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ hidden_states_video = temp_block(
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
+
+ else:
+ if i == 0 and frame > 1:
+ hidden_states = hidden_states + self.temp_pos_embed
+
+ hidden_states = temp_block(
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ )
+
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
+
+
+ if self.is_input_patches:
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous()
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+ output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous()
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+ def get_1d_sincos_temp_embed(self, embed_dim, length):
+ pos = torch.arange(0, length).unsqueeze(1)
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
\ No newline at end of file
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e130569d3a4c48de0335832495a967668d9afcd
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,215 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+
+import numpy as np
+import torch.nn as nn
+
+from einops import repeat
+
+
+#################################################################################
+# Unet Utils #
+#################################################################################
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+# class HybridConditioner(nn.Module):
+
+# def __init__(self, c_concat_config, c_crossattn_config):
+# super().__init__()
+# self.concat_conditioner = instantiate_from_config(c_concat_config)
+# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+# def forward(self, c_concat, c_crossattn):
+# c_concat = self.concat_conditioner(c_concat)
+# c_crossattn = self.crossattn_conditioner(c_crossattn)
+# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += torch.DoubleTensor([matmul_ops])
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
\ No newline at end of file
diff --git a/sample/__pycache__/pipeline_latte.cpython-312.pyc b/sample/__pycache__/pipeline_latte.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..780dfa4d72a653b1f9fd1ca272a6d2e2921ab166
Binary files /dev/null and b/sample/__pycache__/pipeline_latte.cpython-312.pyc differ
diff --git a/sample/ffs.sh b/sample/ffs.sh
new file mode 100644
index 0000000000000000000000000000000000000000..70b646d3cad2a02e965910aa54c20c06f5739d3e
--- /dev/null
+++ b/sample/ffs.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=7
+
+python sample/sample.py \
+--config ./configs/ffs/ffs_sample.yaml \
+--ckpt ./share_ckpts/ffs.pt \
+--save_video_path ./test
\ No newline at end of file
diff --git a/sample/ffs_ddp.sh b/sample/ffs_ddp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..08b474a861aa9f21d28c7edc961df87334a6d140
--- /dev/null
+++ b/sample/ffs_ddp.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=6,7
+
+torchrun --nnodes=1 --nproc_per_node=2 sample/sample_ddp.py \
+--config ./configs/ffs/ffs_sample.yaml \
+--ckpt ./share_ckpts/ffs.pt \
+--save_video_path ./test
diff --git a/sample/pipeline_latte.py b/sample/pipeline_latte.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7860f8eff719fff0842c60e3e3f3fa944e041a3
--- /dev/null
+++ b/sample/pipeline_latte.py
@@ -0,0 +1,783 @@
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import einops
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, Transformer2DModel
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from diffusers.utils import (
+ BACKENDS_MAPPING,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput
+from dataclasses import dataclass
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PixArtAlphaPipeline
+
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+ >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+@dataclass
+class VideoPipelineOutput(BaseOutput):
+ video: torch.Tensor
+
+
+class LattePipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co./docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co./PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co./docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#ยฎโขยฉโข&@ยทยบยฝยพยฟยกยง~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKL,
+ transformer: Transformer2DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
+ def mask_text_embeddings(self, emb, mask):
+ if emb.shape[0] == 1:
+ keep_index = mask.sum().item()
+ return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
+ else:
+ masked_feature = emb * mask[:, None, :, None] # 1 120 4096
+ return masked_feature, emb.shape[2]
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ mask_feature: bool = True,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (bool, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ mask_feature: (bool, defaults to `True`):
+ If `True`, the function will mask the text embeddings.
+ """
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = 120
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds_attention_mask = attention_mask
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ # Perform additional masking.
+ if mask_feature and not embeds_initially_provided:
+ prompt_embeds = prompt_embeds.unsqueeze(1)
+ masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
+ masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
+ masked_negative_prompt_embeds = (
+ negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
+ )
+
+ # import torch.nn.functional as F
+
+ # padding = (0, 0, 0, 113) # (ๅทฆ, ๅณ, ไธ, ไธ)
+ # masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0)
+ # masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0)
+
+ # print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...])
+
+ return masked_prompt_embeds, masked_negative_prompt_embeds
+ # return masked_prompt_embeds_, masked_negative_prompt_embeds_
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0โ31EF CJK Strokes
+ # 31F0โ31FF Katakana Phonetic Extensions
+ # 3200โ32FF Enclosed CJK Letters and Months
+ # 3300โ33FF CJK Compatibility
+ # 3400โ4DBF CJK Unified Ideographs Extension A
+ # 4DC0โ4DFF Yijing Hexagram Symbols
+ # 4E00โ9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # ะฒัะต ะฒะธะดั ัะธัะต / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # ะบะฐะฒััะบะธ ะบ ะพะดะฝะพะผั ััะฐะฝะดะฐััั
+ caption = re.sub(r"[`ยดยซยปโโยจ]", '"', caption)
+ caption = re.sub(r"[โโ]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xั
ร]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ video_length: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ mask_feature: bool = True,
+ enable_temporal_attentions: bool = True,
+ enable_vae_temporal_decoder: bool = False,
+ ) -> Union[VideoPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (ฮท) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # 1. Check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ self.check_inputs(
+ prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ mask_feature=mask_feature,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ video_length,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ enable_temporal_attentions=enable_temporal_attentions,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == 'latents':
+ if enable_vae_temporal_decoder:
+ video = self.decode_latents_with_temporal_decoder(latents)
+ else:
+ video = self.decode_latents(latents)
+ else:
+ video = latents
+ return VideoPipelineOutput(video=video)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return VideoPipelineOutput(video=video)
+
+ def decode_latents(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / self.vae.config.scaling_factor * latents
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+ for frame_idx in range(latents.shape[0]):
+ video.append(self.vae.decode(
+ latents[frame_idx:frame_idx+1]).sample)
+ video = torch.cat(video)
+ video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
+ video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ return video
+
+ def decode_latents_with_temporal_decoder(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / self.vae.config.scaling_factor * latents
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+
+ decode_chunk_size = 14
+ for frame_idx in range(0, latents.shape[0], decode_chunk_size):
+ num_frames_in = latents[frame_idx : frame_idx + decode_chunk_size].shape[0]
+
+ decode_kwargs = {}
+ decode_kwargs["num_frames"] = num_frames_in
+
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+decode_chunk_size], **decode_kwargs).sample)
+
+ video = torch.cat(video)
+ video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
+ video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ return video
\ No newline at end of file
diff --git a/sample/sample.py b/sample/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..e99e77651f70a63daa62702bf3c79bcad0a64134
--- /dev/null
+++ b/sample/sample.py
@@ -0,0 +1,138 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Sample new images from a pre-trained Latte.
+"""
+import os
+import sys
+try:
+ import utils
+
+ from diffusion import create_diffusion
+ from utils import find_model
+except:
+ sys.path.append(os.path.split(sys.path[0])[0])
+
+ import utils
+
+ from diffusion import create_diffusion
+ from utils import find_model
+
+import torch
+import argparse
+import torchvision
+
+from einops import rearrange
+from models import get_models
+from torchvision.utils import save_image
+from diffusers.models import AutoencoderKL
+from models.clip import TextEmbedder
+import imageio
+from omegaconf import OmegaConf
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+def main(args):
+ # Setup PyTorch:
+ # torch.manual_seed(args.seed)
+ torch.set_grad_enabled(False)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # device = "cpu"
+
+ if args.ckpt is None:
+ assert args.model == "Latte-XL/2", "Only Latte-XL/2 models are available for auto-download."
+ assert args.image_size in [256, 512]
+ assert args.num_classes == 1000
+
+ using_cfg = args.cfg_scale > 1.0
+
+ # Load model:
+ latent_size = args.image_size // 8
+ args.latent_size = latent_size
+ model = get_models(args).to(device)
+
+ if args.use_compile:
+ model = torch.compile(model)
+
+ # a pre-trained model or load a custom Latte checkpoint from train.py:
+ ckpt_path = args.ckpt
+ state_dict = find_model(ckpt_path)
+ model.load_state_dict(state_dict)
+
+ model.eval() # important!
+ diffusion = create_diffusion(str(args.num_sampling_steps))
+ # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
+ # text_encoder = TextEmbedder().to(device)
+
+ if args.use_fp16:
+ print('WARNING: using half percision for inferencing!')
+ vae.to(dtype=torch.float16)
+ model.to(dtype=torch.float16)
+ # text_encoder.to(dtype=torch.float16)
+
+ # Labels to condition the model with (feel free to change):
+
+ # Create sampling noise:
+ if args.use_fp16:
+ z = torch.randn(1, args.num_frames, 4, latent_size, latent_size, dtype=torch.float16, device=device) # b c f h w
+ else:
+ z = torch.randn(1, args.num_frames, 4, latent_size, latent_size, device=device)
+
+ # Setup classifier-free guidance:
+ # z = torch.cat([z, z], 0)
+ if using_cfg:
+ z = torch.cat([z, z], 0)
+ y = torch.randint(0, args.num_classes, (1,), device=device)
+ y_null = torch.tensor([101] * 1, device=device)
+ y = torch.cat([y, y_null], dim=0)
+ model_kwargs = dict(y=y, cfg_scale=args.cfg_scale, use_fp16=args.use_fp16)
+ sample_fn = model.forward_with_cfg
+ else:
+ sample_fn = model.forward
+ model_kwargs = dict(y=None, use_fp16=args.use_fp16)
+
+ # Sample images:
+ if args.sample_method == 'ddim':
+ samples = diffusion.ddim_sample_loop(
+ sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
+ )
+ elif args.sample_method == 'ddpm':
+ samples = diffusion.p_sample_loop(
+ sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
+ )
+
+ print(samples.shape)
+ if args.use_fp16:
+ samples = samples.to(dtype=torch.float16)
+ b, f, c, h, w = samples.shape
+ samples = rearrange(samples, 'b f c h w -> (b f) c h w')
+ samples = vae.decode(samples / 0.18215).sample
+ samples = rearrange(samples, '(b f) c h w -> b f c h w', b=b)
+ # Save and display images:
+
+ if not os.path.exists(args.save_video_path):
+ os.makedirs(args.save_video_path)
+
+
+ video_ = ((samples[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
+ video_save_path = os.path.join(args.save_video_path, 'sample' + '.mp4')
+ print(video_save_path)
+ imageio.mimwrite(video_save_path, video_, fps=8, quality=9)
+ print('save path {}'.format(args.save_video_path))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/ucf101/ucf101_sample.yaml")
+ parser.add_argument("--ckpt", type=str, default="")
+ parser.add_argument("--save_video_path", type=str, default="./sample_videos/")
+ args = parser.parse_args()
+ omega_conf = OmegaConf.load(args.config)
+ omega_conf.ckpt = args.ckpt
+ omega_conf.save_video_path = args.save_video_path
+ main(omega_conf)
diff --git a/sample/sample_ddp.py b/sample/sample_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9ca3ce027daa2e104cb6605d02864bf702f31b5
--- /dev/null
+++ b/sample/sample_ddp.py
@@ -0,0 +1,199 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Samples a large number of images from a pre-trained Latte model using DDP.
+Subsequently saves a .npz file that can be used to compute FVD and other
+evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
+
+For a simple single-GPU/CPU sampling script, see sample.py.
+"""
+import io
+import os
+import sys
+import torch
+sys.path.append(os.path.split(sys.path[0])[0])
+import torch.distributed as dist
+from utils import find_model
+from diffusion import create_diffusion
+from diffusers.models import AutoencoderKL
+from tqdm import tqdm
+import os
+from PIL import Image
+import numpy as np
+import math
+import argparse
+import imageio
+from omegaconf import OmegaConf
+from models import get_models
+from einops import rearrange
+
+
+def create_npz_from_sample_folder(sample_dir, num=50_000):
+ """
+ Builds a single .npz file from a folder of .png samples.
+ """
+ samples = []
+ for i in tqdm(range(num), desc="Building .npz file from samples"):
+ sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
+ samples.append(sample_np)
+ samples = np.stack(samples)
+ assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
+ npz_path = f"{sample_dir}.npz"
+ np.savez(npz_path, arr_0=samples)
+ print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
+ return npz_path
+
+
+def main(args):
+ """
+ Run sampling.
+ """
+ torch.backends.cuda.matmul.allow_tf32 = True # True: fast but may lead to some small numerical differences
+ assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
+ torch.set_grad_enabled(False)
+
+ # Setup DDP:
+ dist.init_process_group("nccl")
+ rank = dist.get_rank()
+ device = rank % torch.cuda.device_count()
+ if args.seed:
+ seed = args.seed * dist.get_world_size() + rank
+ torch.manual_seed(seed)
+ torch.cuda.set_device(device)
+ # print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
+
+ if args.ckpt is None:
+ assert args.model == "Latte-XL/2", "Only Latte-XL/2 models are available for auto-download."
+ assert args.image_size in [256, 512]
+ assert args.num_classes == 1000
+
+ # Load model:
+ latent_size = args.image_size // 8
+ args.latent_size = latent_size
+ model = get_models(args).to(device)
+
+ if args.use_compile:
+ model = torch.compile(model)
+
+ # a pre-trained model or load a custom Latte checkpoint from train.py:
+ ckpt_path = args.ckpt
+ state_dict = find_model(ckpt_path)
+ model.load_state_dict(state_dict)
+ model.eval() # important!
+ diffusion = create_diffusion(str(args.num_sampling_steps))
+ # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
+ # vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-ema").to(device)
+
+ if args.use_fp16:
+ print('WARNING: using half percision for inferencing!')
+ vae.to(dtype=torch.float16)
+ model.to(dtype=torch.float16)
+ # text_encoder.to(dtype=torch.float16)
+
+ assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
+ using_cfg = args.cfg_scale > 1.0
+
+ # Create folder to save samples:
+ # model_string_name = args.model.replace("/", "-")
+ # ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
+ # folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \
+ # f"cfg-{args.cfg_scale}-seed-{args.seed}"
+ # sample_folder_dir = f"{args.sample_dir}/{folder_name}"
+ sample_folder_dir = args.save_video_path
+ if args.seed:
+ sample_folder_dir = args.save_video_path + '-seed-' + str(args.seed)
+ if rank == 0:
+ os.makedirs(sample_folder_dir, exist_ok=True)
+ print(f"Saving .mp4 samples at {sample_folder_dir}")
+ dist.barrier()
+
+ # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
+ n = args.per_proc_batch_size
+ global_batch_size = n * dist.get_world_size()
+ # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
+ total_samples = int(math.ceil(args.num_fvd_samples / global_batch_size) * global_batch_size)
+ if rank == 0:
+ print(f"Total number of images that will be sampled: {total_samples}")
+ assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
+ samples_needed_this_gpu = int(total_samples // dist.get_world_size())
+ assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
+ iterations = int(samples_needed_this_gpu // n)
+ pbar = range(iterations)
+ pbar = tqdm(pbar) if rank == 0 else pbar
+ total = 0
+ for _ in pbar:
+ # Sample inputs:
+ if args.use_fp16:
+ z = torch.randn(n, args.num_frames, 4, latent_size, latent_size, dtype=torch.float16, device=device)
+ else:
+ z = torch.randn(n, args.num_frames, 4, latent_size, latent_size, device=device)
+
+ # Setup classifier-free guidance:
+ if using_cfg:
+ z = torch.cat([z, z], 0)
+ y = torch.randint(0, args.num_classes, (n,), device=device)
+ y_null = torch.tensor([101] * n, device=device)
+ y = torch.cat([y, y_null], dim=0)
+ model_kwargs = dict(y=y, cfg_scale=args.cfg_scale, use_fp16=args.use_fp16)
+ sample_fn = model.forward_with_cfg
+ else:
+ model_kwargs = dict(y=None, use_fp16=args.use_fp16)
+ sample_fn = model.forward
+
+ # Sample images:
+ if args.sample_method == 'ddim':
+ samples = diffusion.ddim_sample_loop(
+ sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
+ )
+ elif args.sample_method == 'ddpm':
+ samples = diffusion.p_sample_loop(
+ sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
+ )
+
+
+ if using_cfg:
+ samples, _ = samples.chunk(2, dim=0) # Remove null class samples
+
+ if args.use_fp16:
+ samples = samples.to(dtype=torch.float16)
+
+ b, f, c, h, w = samples.shape
+ samples = rearrange(samples, 'b f c h w -> (b f) c h w')
+ samples = vae.decode(samples / 0.18215).sample
+ samples = rearrange(samples, '(b f) c h w -> b f c h w', b=b)
+
+ # Save samples to disk as individual .png files
+ for i, sample in enumerate(samples):
+ sample = ((sample * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
+ index = i * dist.get_world_size() + rank + total
+ # Image.fromarray(sample).save(f"{sample_folder_dir}/{index:04d}.png")
+ sample_save_path = f"{sample_folder_dir}/{index:04d}.mp4"
+ imageio.mimwrite(sample_save_path, sample, fps=8, quality=9)
+ total += global_batch_size
+
+ # Make sure all processes have finished saving their samples before attempting to convert to .npz
+ dist.barrier()
+ # if rank == 0:
+ # create_npz_from_sample_folder(sample_folder_dir, args.num_fvd_samples)
+ # print("Done.")
+ # dist.barrier()
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
+ parser.add_argument("--ckpt", type=str, default="")
+ parser.add_argument("--save_video_path", type=str, default="./sample_videos/")
+ parser.add_argument("--save_ceph", default=False, action='store_true')
+ args = parser.parse_args()
+ omega_conf = OmegaConf.load(args.config)
+ omega_conf.ckpt = args.ckpt
+ omega_conf.save_video_path = args.save_video_path
+ omega_conf.save_ceph = args.save_ceph
+ main(omega_conf)
\ No newline at end of file
diff --git a/sample/sample_t2x.py b/sample/sample_t2x.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28ebce87e4d1b838f9bc18b068f963526494ef9
--- /dev/null
+++ b/sample/sample_t2x.py
@@ -0,0 +1,171 @@
+import os
+import torch
+import argparse
+import torchvision
+
+
+from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
+ EulerDiscreteScheduler, DPMSolverMultistepScheduler,
+ HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
+ DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
+from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
+from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
+from omegaconf import OmegaConf
+from transformers import T5EncoderModel, T5Tokenizer
+
+import os, sys
+sys.path.append(os.path.split(sys.path[0])[0])
+from pipeline_latte import LattePipeline
+from models import get_models
+from utils import save_video_grid
+import imageio
+from torchvision.utils import save_image
+
+def main(args):
+ # torch.manual_seed(args.seed)
+ torch.set_grad_enabled(False)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ transformer_model = get_models(args).to(device, dtype=torch.float16)
+
+ if args.enable_vae_temporal_decoder:
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+ else:
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)
+ tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+ text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
+
+ # set eval mode
+ transformer_model.eval()
+ vae.eval()
+ text_encoder.eval()
+
+ if args.sample_method == 'DDIM':
+ scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type,
+ clip_sample=False)
+ elif args.sample_method == 'EulerDiscrete':
+ scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif args.sample_method == 'DDPM':
+ scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type,
+ clip_sample=False)
+ elif args.sample_method == 'DPMSolverMultistep':
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif args.sample_method == 'DPMSolverSinglestep':
+ scheduler = DPMSolverSinglestepScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif args.sample_method == 'PNDM':
+ scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif args.sample_method == 'HeunDiscrete':
+ scheduler = HeunDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif args.sample_method == 'EulerAncestralDiscrete':
+ scheduler = EulerAncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif args.sample_method == 'DEISMultistep':
+ scheduler = DEISMultistepScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+ elif args.sample_method == 'KDPM2AncestralDiscrete':
+ scheduler = KDPM2AncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule,
+ variance_type=args.variance_type)
+
+
+ videogen_pipeline = LattePipeline(vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer_model).to(device)
+ # videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+ if not os.path.exists(args.save_img_path):
+ os.makedirs(args.save_img_path)
+
+ # video_grids = []
+ for num_prompt, prompt in enumerate(args.text_prompt):
+ print('Processing the ({}) prompt'.format(prompt))
+ videos = videogen_pipeline(prompt,
+ video_length=args.video_length,
+ height=args.image_size[0],
+ width=args.image_size[1],
+ num_inference_steps=args.num_sampling_steps,
+ guidance_scale=args.guidance_scale,
+ enable_temporal_attentions=args.enable_temporal_attentions,
+ num_images_per_prompt=1,
+ mask_feature=True,
+ enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
+ ).video
+ if videos.shape[1] == 1:
+ try:
+ save_image(videos[0][0], args.save_img_path + prompt.replace(' ', '_') + '.png')
+ except:
+ save_image(videos[0][0], args.save_img_path + str(num_prompt)+ '.png')
+ print('Error when saving {}'.format(prompt))
+ else:
+ try:
+ imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0
+ except:
+ print('Error when saving {}'.format(prompt))
+ # save video grid
+ # video_grids.append(videos)
+
+ # video_grids = torch.cat(video_grids, dim=0)
+
+ # video_grids = save_video_grid(video_grids)
+
+ # # torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6)
+ # imageio.mimwrite(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=8, quality=6)
+ # print('save path {}'.format(args.save_img_path))
+
+ # save_videos_grid(video, f"./{prompt}.gif")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/wbv10m_train.yaml")
+ args = parser.parse_args()
+
+ main(OmegaConf.load(args.config))
+
diff --git a/sample/sky.sh b/sample/sky.sh
new file mode 100644
index 0000000000000000000000000000000000000000..752e076149cc91bf431a25c02d5332ba2164fc45
--- /dev/null
+++ b/sample/sky.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=7
+
+python sample/sample.py \
+--config ./configs/sky/sky_sample.yaml \
+--ckpt ./share_ckpts/skytimelapse.pt \
+--save_video_path ./test
\ No newline at end of file
diff --git a/sample/sky_ddp.sh b/sample/sky_ddp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6cead37f7dc1db90bcc01a257e423cc4c1ae3dc5
--- /dev/null
+++ b/sample/sky_ddp.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=6,7
+
+torchrun --nnodes=1 --nproc_per_node=2 sample/sample_ddp.py \
+--config ./configs/sky/sky_sample.yaml \
+--ckpt ./share_ckpts/skytimelapse.pt \
+--save_video_path ./test
diff --git a/sample/t2i.sh b/sample/t2i.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1cb9fe369c1c79e9c37067cf90a162e2f1975b03
--- /dev/null
+++ b/sample/t2i.sh
@@ -0,0 +1,2 @@
+export CUDA_VISIBLE_DEVICES=7
+python sample/sample_t2x.py --config configs/t2x/t2i_sample.yaml
\ No newline at end of file
diff --git a/sample/t2v.sh b/sample/t2v.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bbd1982bf67bc4d7ee01f531419e577c8de486ac
--- /dev/null
+++ b/sample/t2v.sh
@@ -0,0 +1,2 @@
+export CUDA_VISIBLE_DEVICES=0
+python sample/sample_t2x.py --config configs/t2x/t2v_sample.yaml
\ No newline at end of file
diff --git a/sample/taichi.sh b/sample/taichi.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d31aea4878f9fa7c1f5e572b58eb4345e6b8b89
--- /dev/null
+++ b/sample/taichi.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=7
+
+python sample/sample.py \
+--config ./configs/ucf101/taichi_sample.yaml \
+--ckpt ./share_ckpts/taichi-hd.pt \
+--save_video_path ./test
\ No newline at end of file
diff --git a/sample/taichi_ddp.sh b/sample/taichi_ddp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7b93ccd258e2e632667cfa9c3baf548757e74575
--- /dev/null
+++ b/sample/taichi_ddp.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=6,7
+
+torchrun --nnodes=1 --nproc_per_node=2 sample/sample_ddp.py \
+--config ./configs/taichi/taichi_sample.yaml \
+--ckpt ./share_ckpts/taichi-hd.pt \
+--save_video_path ./test \
diff --git a/sample/ucf101.sh b/sample/ucf101.sh
new file mode 100644
index 0000000000000000000000000000000000000000..df8e1aef7aed990b800f6ba5bb80e05215bc5c78
--- /dev/null
+++ b/sample/ucf101.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=7
+python sample/sample.py \
+--config ./configs/ucf101/ucf101_sample.yaml \
+--ckpt ./share_ckpts/ucf101.pt \
+--save_video_path ./test
\ No newline at end of file
diff --git a/sample/ucf101_ddp.sh b/sample/ucf101_ddp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8738ea6b6a8f800caafe9fb770dd84188c817c51
--- /dev/null
+++ b/sample/ucf101_ddp.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=6,7
+torchrun --nnodes=1 --nproc_per_node=2 --master_port=29520 sample/sample_ddp.py \
+--config ./configs/ucf101/ucf101_sample.yaml \
+--ckpt ./share_ckpts/ucf101.pt \
+--save_video_path ./test
diff --git a/sample_videos/t2v-temp.mp4 b/sample_videos/t2v-temp.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7c4b98a89a4cd0463a2e39874a159330bb24cf2b
Binary files /dev/null and b/sample_videos/t2v-temp.mp4 differ
diff --git a/slurm_scripts/ffs.slurm b/slurm_scripts/ffs.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..512161c9c4d87aeca6685dbb6f604d2cd343ff20
--- /dev/null
+++ b/slurm_scripts/ffs.slurm
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+#SBATCH --job-name=Latte-ffs #To give your job a name, replace "Latte-ffs" with an appropriate name
+#SBATCH --partition group-name
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=16
+#SBATCH --time=500:00:00
+#SBATCH --output=slurm_log/%j.out
+#SBATCH --error=slurm_log/%j.err
+
+source ~/.bashrc
+
+conda activate latte
+
+srun python train.py --config ./configs/ffs/ffs_train.yaml
\ No newline at end of file
diff --git a/slurm_scripts/sky.slurm b/slurm_scripts/sky.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..97b9cc5d77037c372d488ccc056afd443584b902
--- /dev/null
+++ b/slurm_scripts/sky.slurm
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+#SBATCH --job-name=Latte-ffs
+#SBATCH --partition group-name
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=16
+#SBATCH --time=500:00:00
+#SBATCH --output=slurm_log/%j.out
+#SBATCH --error=slurm_log/%j.err
+
+source ~/.bashrc
+
+conda activate latte
+
+srun python train.py --config ./configs/sky/sky_train.yaml
\ No newline at end of file
diff --git a/slurm_scripts/taichi.slurm b/slurm_scripts/taichi.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..d525946ef092e5bf8a3f60676ecc55e879e2176b
--- /dev/null
+++ b/slurm_scripts/taichi.slurm
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+#SBATCH --job-name=Latte-ffs
+#SBATCH --partition group-name
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=16
+#SBATCH --time=500:00:00
+#SBATCH --output=slurm_log/%j.out
+#SBATCH --error=slurm_log/%j.err
+
+source ~/.bashrc
+
+conda activate latte
+
+srun python train.py --config ./configs/taichi/taichi_train.yaml
\ No newline at end of file
diff --git a/slurm_scripts/ucf101.slurm b/slurm_scripts/ucf101.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..4eaeb45ed4229dd525f89c528417650deccb4327
--- /dev/null
+++ b/slurm_scripts/ucf101.slurm
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+#SBATCH --job-name=Latte-ffs
+#SBATCH --partition group-name
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=16
+#SBATCH --time=500:00:00
+#SBATCH --output=slurm_log/%j.out
+#SBATCH --error=slurm_log/%j.err
+
+source ~/.bashrc
+
+conda activate latte
+
+srun python train.py --config ./configs/ucf101/ucf101_train.yaml
\ No newline at end of file
diff --git a/tools/calc_metrics_for_dataset.py b/tools/calc_metrics_for_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5622f91d776277907a9d9312ab076c9ce12c2fee
--- /dev/null
+++ b/tools/calc_metrics_for_dataset.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Calculate quality metrics for previous training run or pretrained network pickle."""
+
+import sys; sys.path.extend(['.', 'tools'])
+import os
+import click
+import tempfile
+import torch
+from omegaconf import OmegaConf
+from tools import dnnlib
+
+from metrics import metric_main
+from metrics import metric_utils
+from tools.torch_utils import training_stats
+from tools.torch_utils import custom_ops
+
+#----------------------------------------------------------------------------
+
+def subprocess_fn(rank, args, temp_dir):
+ dnnlib.util.Logger(should_flush=True)
+
+ # Init torch.distributed.
+ if args.num_gpus > 1:
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+ if os.name == 'nt':
+ init_method = 'file:///' + init_file.replace('\\', '/')
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+ else:
+ init_method = f'file://{init_file}'
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+ # Init torch_utils.
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+ if rank != 0 or not args.verbose:
+ custom_ops.verbosity = 'none'
+
+ # Print network summary.
+ device = torch.device('cuda', rank)
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+
+ # Calculate each metric.
+ for metric in args.metrics:
+ if rank == 0 and args.verbose:
+ print(f'Calculating {metric}...')
+ progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+ result_dict = metric_main.calc_metric(
+ metric=metric,
+ dataset_kwargs=args.dataset_kwargs, # real
+ gen_dataset_kwargs=args.gen_dataset_kwargs, # fake
+ generator_as_dataset=args.generator_as_dataset,
+ num_gpus=args.num_gpus,
+ rank=rank,
+ device=device,
+ progress=progress,
+ cache=args.use_cache,
+ num_runs=args.num_runs,
+ )
+
+ if rank == 0:
+ metric_main.report_metric(result_dict, run_dir=args.run_dir)
+
+ if rank == 0 and args.verbose:
+ print()
+
+ # Done.
+ if rank == 0 and args.verbose:
+ print('Exiting...')
+
+#----------------------------------------------------------------------------
+
+class CommaSeparatedList(click.ParamType):
+ name = 'list'
+
+ def convert(self, value, param, ctx):
+ _ = param, ctx
+ if value is None or value.lower() == 'none' or value == '':
+ return []
+ return value.split(',')
+
+#----------------------------------------------------------------------------
+
+def calc_metrics_for_dataset(ctx, metrics, real_data_path, fake_data_path, mirror, resolution, gpus, verbose, use_cache: bool, num_runs: int):
+ dnnlib.util.Logger(should_flush=True)
+
+ # Validate arguments.
+ args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, verbose=verbose)
+ if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+ ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+ if not args.num_gpus >= 1:
+ ctx.fail('--gpus must be at least 1')
+
+ dummy_dataset_cfg = OmegaConf.create({'max_num_frames': 10000})
+
+ # Initialize dataset options for real data.
+ args.dataset_kwargs = dnnlib.EasyDict(
+ class_name='utils.dataset.VideoFramesFolderDataset',
+ path=real_data_path,
+ cfg=dummy_dataset_cfg,
+ xflip=mirror,
+ resolution=resolution,
+ use_labels=False,
+ )
+
+ # Initialize dataset options for fake data.
+ args.gen_dataset_kwargs = dnnlib.EasyDict(
+ class_name='utils.dataset.VideoFramesFolderDataset',
+ path=fake_data_path,
+ cfg=dummy_dataset_cfg,
+ xflip=False,
+ resolution=resolution,
+ use_labels=False,
+ )
+ args.generator_as_dataset = True
+
+ # Print dataset options.
+ if args.verbose:
+ print('Real data options:')
+ print(args.dataset_kwargs)
+
+ print('Fake data options:')
+ print(args.gen_dataset_kwargs)
+
+ print('*' * 50 + 'parting line' + '*' * 50)
+ print('Fake data options:')
+ print(args.gen_dataset_kwargs)
+
+ # Locate run dir.
+ args.run_dir = None
+ args.use_cache = use_cache
+ args.num_runs = num_runs
+
+ # Launch processes.
+ if args.verbose:
+ print('Launching processes...')
+ torch.multiprocessing.set_start_method('spawn')
+ with tempfile.TemporaryDirectory() as temp_dir:
+ if args.num_gpus == 1:
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+ else:
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fvd2048_16f,fid50k_full', show_default=True)
+@click.option('--real_data_path', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
+@click.option('--fake_data_path', help='Generated images (directory or zip)', metavar='PATH')
+@click.option('--mirror', help='Should we mirror the real data?', type=bool, metavar='BOOL')
+@click.option('--resolution', help='Resolution for the source dataset', type=int, metavar='INT')
+@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+@click.option('--verbose', help='Print optional information', type=bool, default=False, metavar='BOOL', show_default=True)
+@click.option('--use_cache', help='Use stats cache', type=bool, default=True, metavar='BOOL', show_default=True)
+@click.option('--num_runs', help='Number of runs', type=int, default=1, metavar='INT', show_default=True)
+def calc_metrics_cli_wrapper(ctx, *args, **kwargs):
+ calc_metrics_for_dataset(ctx, *args, **kwargs)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ calc_metrics_cli_wrapper() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/tools/convert_videos_to_frames.py b/tools/convert_videos_to_frames.py
new file mode 100644
index 0000000000000000000000000000000000000000..d94b87b9e5dcb1ef7a30011ef066d907c43dbd63
--- /dev/null
+++ b/tools/convert_videos_to_frames.py
@@ -0,0 +1,109 @@
+"""
+Converts a dataset of mp4 videos into a dataset of video frames
+I.e. a directory of mp4 files becomes a directory of directories of frames
+This speeds up loading during training because we do not need
+"""
+import os
+from typing import List
+import argparse
+from pathlib import Path
+from multiprocessing import Pool
+from collections import Counter
+
+from PIL import Image
+import torchvision.transforms.functional as TVF
+from moviepy.editor import VideoFileClip
+from tqdm import tqdm
+
+
+def convert_videos_to_frames(source_dir: os.PathLike, target_dir: os.PathLike, num_workers: int, video_ext: str, **process_video_kwargs):
+ broken_clips_dir = f'{target_dir}_broken_clips'
+ os.makedirs(target_dir, exist_ok=True)
+ os.makedirs(broken_clips_dir, exist_ok=True)
+
+ clips_paths = [cp for cp in listdir_full_paths(source_dir) if cp.endswith(video_ext)]
+ clips_fps = []
+ tasks_kwargs = [dict(
+ clip_path=cp,
+ target_dir=target_dir,
+ broken_clips_dir=broken_clips_dir,
+ **process_video_kwargs,
+ ) for cp in clips_paths]
+ pool = Pool(processes=num_workers)
+
+ for fps in tqdm(pool.imap_unordered(task_proxy, tasks_kwargs), total=len(clips_paths)):
+ clips_fps.append(fps)
+
+ print(f'All possible fps: {Counter(clips_fps).most_common()}')
+
+
+def task_proxy(kwargs):
+ """I do not know, how to pass several arguments to a pool job..."""
+ return process_video(**kwargs)
+
+
+def process_video(
+ clip_path: os.PathLike, target_dir: os.PathLike, force_fps: int=None, target_size: int=None,
+ broken_clips_dir: os.PathLike=None, compute_fps_only: bool=False) -> int:
+
+ clip_name = os.path.basename(clip_path)
+ clip_name = clip_name[:clip_name.rfind('.')]
+
+ try:
+ clip = VideoFileClip(clip_path)
+ except KeyboardInterrupt:
+ raise
+ except Exception as e:
+ print(f'Coudnt process clip: {clip_path}')
+ if not broken_clips_dir is None:
+ Path(os.path.join(broken_clips_dir, clip_name)).touch()
+ return 0
+
+ if compute_fps_only:
+ return clip.fps
+
+ fps = clip.fps if force_fps is None else force_fps
+ clip_target_dir = os.path.join(target_dir, clip_name)
+ clip_target_dir = clip_target_dir.replace('#', '_')
+ os.makedirs(clip_target_dir, exist_ok=True)
+
+ frame_idx = 0
+ for frame in clip.iter_frames(fps=fps):
+ frame = Image.fromarray(frame)
+ h, w = frame.size[0], frame.size[1]
+ min_size = min(h, w)
+ if not target_size is None:
+ # frame = TVF.resize(frame, size=target_size, interpolation=Image.LANCZOS)
+ # frame = TVF.center_crop(frame, output_size=(target_size, target_size))
+ frame = TVF.center_crop(frame, output_size=(min_size, min_size))
+ frame = TVF.resize(frame, size=target_size, interpolation=Image.LANCZOS)
+ frame.save(os.path.join(clip_target_dir, f'{frame_idx:06d}.jpg'), q=95)
+ frame_idx += 1
+
+ return clip.fps
+
+
+def listdir_full_paths(d) -> List[os.PathLike]:
+ return sorted([os.path.join(d, x) for x in os.listdir(d)])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Convert a dataset of mp4 files into a dataset of individual frames')
+ parser.add_argument('-s', '--source_dir', type=str, help='Path to the source dataset')
+ parser.add_argument('-t', '--target_dir', type=str, help='Where to save the new dataset')
+ parser.add_argument('--video_ext', type=str, default='avi', help='Video extension')
+ parser.add_argument('--target_size', type=int, default=128, help='What size should we resize to?')
+ parser.add_argument('--force_fps', type=int, help='What fps should we run videos with?')
+ parser.add_argument('--num_workers', type=int, default=8, help='Number of processes to launch')
+ parser.add_argument('--compute_fps_only', action='store_true', help='Should we just compute fps?')
+ args = parser.parse_args()
+
+ convert_videos_to_frames(
+ source_dir=args.source_dir,
+ target_dir=args.target_dir,
+ target_size=args.target_size,
+ force_fps=args.force_fps,
+ num_workers=args.num_workers,
+ video_ext=args.video_ext,
+ compute_fps_only=args.compute_fps_only,
+ )
diff --git a/tools/dnnlib/__init__.py b/tools/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f08cf36f11f9b0fd94c1b7caeadf69b98375b04
--- /dev/null
+++ b/tools/dnnlib/__init__.py
@@ -0,0 +1,9 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/tools/dnnlib/util.py b/tools/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aa5f409b65385d7f4c459121a59826e64233152
--- /dev/null
+++ b/tools/dnnlib/util.py
@@ -0,0 +1,480 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union, Dict
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+ def to_dict(self) -> Dict:
+ return {k: (v.to_dict() if isinstance(v, EasyDict) else v) for (k, v) in self.items()}
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
diff --git a/tools/eval_metrics.sh b/tools/eval_metrics.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9e513e9520b25540a0956dbd86b1fcf68bf83f66
--- /dev/null
+++ b/tools/eval_metrics.sh
@@ -0,0 +1,7 @@
+export CUDA_VISIBLE_DEVICES=0
+python tools/calc_metrics_for_dataset.py \
+--real_data_path /path/to/real_data//images \
+--fake_data_path /path/to/fake_data/images \
+--mirror 1 --gpus 1 --resolution 256 \
+--metrics fvd2048_16f \
+--verbose 0 --use_cache 0
\ No newline at end of file
diff --git a/tools/metrics/__init__.py b/tools/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55435dbaba8542b9080b8bdcbc8ed2015d445a4b
--- /dev/null
+++ b/tools/metrics/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
diff --git a/tools/metrics/frechet_inception_distance.py b/tools/metrics/frechet_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..a569cc60ef0fb555ffc3418202ff0eacaa83e7b0
--- /dev/null
+++ b/tools/metrics/frechet_inception_distance.py
@@ -0,0 +1,54 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+NUM_FRAMES_IN_BATCH = {128: 32, 256: 32, 512: 8, 1024: 2}
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution]
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, use_image_dataset=True).get_mean_cov()
+
+ if opts.generator_as_dataset:
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
+ gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
+ gen_kwargs = dict(use_image_dataset=True)
+ else:
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
+ gen_opts = opts
+ gen_kwargs = dict()
+
+ mu_gen, sigma_gen = compute_gen_stats_fn(
+ opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs, batch_size=batch_size,
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, **gen_kwargs).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
+#----------------------------------------------------------------------------
diff --git a/tools/metrics/frechet_video_distance.py b/tools/metrics/frechet_video_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..43170d81d0fcdc6c3959c8081a41da5b43da0e44
--- /dev/null
+++ b/tools/metrics/frechet_video_distance.py
@@ -0,0 +1,62 @@
+"""
+Frechet Video Distance (FVD). Matches the original tensorflow implementation from
+https://github.com/google-research/google-research/blob/master/frechet_video_distance/frechet_video_distance.py
+up to the upsampling operation. Note that this tf.hub I3D model is different from the one released in the I3D repo.
+"""
+
+import copy
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+NUM_FRAMES_IN_BATCH = {128: 128, 256: 128, 512: 64, 1024: 32}
+
+#----------------------------------------------------------------------------
+
+def compute_fvd(opts, max_real: int, num_gen: int, num_frames: int, realdata_subsample_factor: int=3, gendata_subsample_factor: int=1):
+ # Perfectly reproduced torchscript version of the I3D model, trained on Kinetics-400, used here:
+ # https://github.com/google-research/google-research/blob/master/frechet_video_distance/frechet_video_distance.py
+ # Note that the weights on tf.hub (used in the script above) differ from the original released weights
+ detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
+ detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.
+
+ # real data args
+ opts = copy.deepcopy(opts)
+ opts.dataset_kwargs.load_n_consecutive = num_frames
+ # opts.dataset_kwargs.load_n_consecutive = None
+ opts.dataset_kwargs.subsample_factor = realdata_subsample_factor
+ opts.dataset_kwargs.discard_short_videos = True
+ batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution] // num_frames
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, rel_lo=0, rel_hi=0,
+ capture_mean_cov=True, max_items=max_real, temporal_detector=True, batch_size=batch_size).get_mean_cov()
+
+ if opts.generator_as_dataset:
+ # fake data args
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
+ gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
+ gen_opts.dataset_kwargs.load_n_consecutive = num_frames
+ gen_opts.dataset_kwargs.load_n_consecutive_random_offset = False
+ gen_opts.dataset_kwargs.subsample_factor = gendata_subsample_factor
+ gen_kwargs = dict()
+ else:
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
+ gen_opts = opts
+ gen_kwargs = dict(num_video_frames=num_frames, subsample_factor=gendata_subsample_factor)
+
+ mu_gen, sigma_gen = compute_gen_stats_fn(
+ opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs, rel_lo=0, rel_hi=1, capture_mean_cov=True,
+ max_items=num_gen, temporal_detector=True, batch_size=batch_size, **gen_kwargs).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
+#----------------------------------------------------------------------------
diff --git a/tools/metrics/inception_score.py b/tools/metrics/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a0009e98e14181d9fa62b06bf23e15c506c2ec
--- /dev/null
+++ b/tools/metrics/inception_score.py
@@ -0,0 +1,48 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+ if opts.generator_as_dataset:
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
+ gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
+ gen_kwargs = dict(use_image_dataset=True)
+ else:
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
+ gen_opts = opts
+ gen_kwargs = dict()
+
+ gen_probs = compute_gen_stats_fn(
+ opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ capture_all=True, max_items=num_gen, **gen_kwargs).get_all()
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
+ scores = []
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ print(kl)
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/tools/metrics/kernel_inception_distance.py b/tools/metrics/kernel_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..772f9f941dc0dd1644ed35206bafe80511179807
--- /dev/null
+++ b/tools/metrics/kernel_inception_distance.py
@@ -0,0 +1,46 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real, use_image_dataset=True).get_all()
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ n = real_features.shape[1]
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ kid = t / num_subsets / m
+ return float(kid) * 1000.0
+
+#----------------------------------------------------------------------------
diff --git a/tools/metrics/metric_main.py b/tools/metrics/metric_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..54eeb16a0df44271f393df5804b73a60a70fb55f
--- /dev/null
+++ b/tools/metrics/metric_main.py
@@ -0,0 +1,155 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import time
+import json
+import torch
+import numpy as np
+from tools import dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import inception_score
+from . import video_inception_score
+from . import frechet_video_distance
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+ assert callable(fn)
+ _metric_dict[fn.__name__] = fn
+ return fn
+
+def is_valid_metric(metric):
+ return metric in _metric_dict
+
+def list_valid_metrics():
+ return list(_metric_dict.keys())
+
+def is_power_of_two(n: int) -> bool:
+ return (n & (n-1) == 0) and n != 0
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, num_runs: int=1, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+ assert is_valid_metric(metric)
+ opts = metric_utils.MetricOptions(**kwargs)
+
+ # Calculate.
+ start_time = time.time()
+ all_runs_results = [_metric_dict[metric](opts) for _ in range(num_runs)]
+ total_time = time.time() - start_time
+
+ # Broadcast results.
+ for results in all_runs_results:
+ for key, value in list(results.items()):
+ if opts.num_gpus > 1:
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+ torch.distributed.broadcast(tensor=value, src=0)
+ value = float(value.cpu())
+ results[key] = value
+
+ if num_runs > 1:
+ results = {f'{key}_run{i+1:02d}': value for i, results in enumerate(all_runs_results) for key, value in results.items()}
+ for key, value in all_runs_results[0].items():
+ all_runs_values = [r[key] for r in all_runs_results]
+ results[f'{key}_mean'] = np.mean(all_runs_values)
+ results[f'{key}_std'] = np.std(all_runs_values)
+ else:
+ results = all_runs_results[0]
+
+ # Decorate with metadata.
+ return dnnlib.EasyDict(
+ results = dnnlib.EasyDict(results),
+ metric = metric,
+ total_time = total_time,
+ total_time_str = dnnlib.util.format_time(total_time),
+ num_gpus = opts.num_gpus,
+ )
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+ metric = result_dict['metric']
+ assert is_valid_metric(metric)
+ if run_dir is not None and snapshot_pkl is not None:
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+ print(jsonl_line)
+ if run_dir is not None and os.path.isdir(run_dir):
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+ f.write(jsonl_line + '\n')
+
+#----------------------------------------------------------------------------
+# Primary metrics.
+
+@register_metric
+def fid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+ return dict(fid50k_full=fid)
+
+
+@register_metric
+def kid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k_full=kid)
+
+@register_metric
+def is50k(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+ return dict(is50k_mean=mean, is50k_std=std)
+
+@register_metric
+def fvd2048_16f(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=16)
+ return dict(fvd2048_16f=fvd)
+
+@register_metric
+def fvd2048_128f(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=128)
+ return dict(fvd2048_128f=fvd)
+
+@register_metric
+def fvd2048_128f_subsample8f(opts):
+ """Similar to `fvd2048_128f`, but we sample each 8-th frame"""
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=16, subsample_factor=8)
+ return dict(fvd2048_128f_subsample8f=fvd)
+
+@register_metric
+def isv2048_ucf(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = video_inception_score.compute_isv(opts, num_gen=2048, num_splits=10, backbone='c3d_ucf101')
+ return dict(isv2048_ucf_mean=mean, isv2048_ucf_std=std)
+
+#----------------------------------------------------------------------------
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+ return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k=kid)
+
+#----------------------------------------------------------------------------
diff --git a/tools/metrics/metric_utils.py b/tools/metrics/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ade5ffb000bf717db058645e46b972ed296b3390
--- /dev/null
+++ b/tools/metrics/metric_utils.py
@@ -0,0 +1,335 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import time
+import hashlib
+import random
+import pickle
+import copy
+import uuid
+from urllib.parse import urlparse
+import numpy as np
+import torch
+from tools import dnnlib
+from tools.utils.dataset import video_to_image_dataset_kwargs
+
+#----------------------------------------------------------------------------
+
+class MetricOptions:
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None,
+ progress=None, cache=True, gen_dataset_kwargs={}, generator_as_dataset=False):
+ assert 0 <= rank < num_gpus
+ self.G = G
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+ self.num_gpus = num_gpus
+ self.rank = rank
+ self.device = device if device is not None else torch.device('cuda', rank)
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.cache = cache
+ self.gen_dataset_kwargs = gen_dataset_kwargs
+ self.generator_as_dataset = generator_as_dataset
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+ return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+ assert 0 <= rank < num_gpus
+ key = (url, device)
+ if key not in _feature_detector_cache:
+ is_leader = (rank == 0)
+ if not is_leader and num_gpus > 1:
+ torch.distributed.barrier() # leader goes first
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+ if urlparse(url).path.endswith('.pkl'):
+ _feature_detector_cache[key] = pickle.load(f).to(device)
+ else:
+ _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
+ if is_leader and num_gpus > 1:
+ torch.distributed.barrier() # others follow
+ return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+ self.capture_all = capture_all
+ self.capture_mean_cov = capture_mean_cov
+ self.max_items = max_items
+ self.num_items = 0
+ self.num_features = None
+ self.all_features = None
+ self.raw_mean = None
+ self.raw_cov = None
+
+ def set_num_features(self, num_features):
+ if self.num_features is not None:
+ assert num_features == self.num_features
+ else:
+ self.num_features = num_features
+ self.all_features = []
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+ def is_full(self):
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+ def append(self, x):
+ x = np.asarray(x, dtype=np.float32)
+ assert x.ndim == 2
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+ if self.num_items >= self.max_items:
+ return
+ x = x[:self.max_items - self.num_items]
+
+ self.set_num_features(x.shape[1])
+ self.num_items += x.shape[0]
+ if self.capture_all:
+ self.all_features.append(x)
+ if self.capture_mean_cov:
+ x64 = x.astype(np.float64)
+ self.raw_mean += x64.sum(axis=0)
+ self.raw_cov += x64.T @ x64
+
+ def append_torch(self, x, num_gpus=1, rank=0):
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
+ assert 0 <= rank < num_gpus
+ if num_gpus > 1:
+ ys = []
+ for src in range(num_gpus):
+ y = x.clone()
+ torch.distributed.broadcast(y, src=src)
+ ys.append(y)
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+ self.append(x.cpu().numpy())
+
+ def get_all(self):
+ assert self.capture_all
+ return np.concatenate(self.all_features, axis=0)
+
+ def get_all_torch(self):
+ return torch.from_numpy(self.get_all())
+
+ def get_mean_cov(self):
+ assert self.capture_mean_cov
+ mean = self.raw_mean / self.num_items
+ cov = self.raw_cov / self.num_items
+ cov = cov - np.outer(mean, mean)
+ return mean, cov
+
+ def save(self, pkl_file):
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(self.__dict__, f)
+
+ @staticmethod
+ def load(pkl_file):
+ with open(pkl_file, 'rb') as f:
+ s = dnnlib.EasyDict(pickle.load(f))
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+ obj.__dict__.update(s)
+ return obj
+
+#----------------------------------------------------------------------------
+
+class ProgressMonitor:
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+ self.tag = tag
+ self.num_items = num_items
+ self.verbose = verbose
+ self.flush_interval = flush_interval
+ self.progress_fn = progress_fn
+ self.pfn_lo = pfn_lo
+ self.pfn_hi = pfn_hi
+ self.pfn_total = pfn_total
+ self.start_time = time.time()
+ self.batch_time = self.start_time
+ self.batch_items = 0
+ if self.progress_fn is not None:
+ self.progress_fn(self.pfn_lo, self.pfn_total)
+
+ def update(self, cur_items: int):
+ assert (self.num_items is None) or (cur_items <= self.num_items), f"Wrong `items` values: cur_items={cur_items}, self.num_items={self.num_items}"
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+ return
+ cur_time = time.time()
+ total_time = cur_time - self.start_time
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+ if (self.verbose) and (self.tag is not None):
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+ self.batch_time = cur_time
+ self.batch_items = cur_items
+
+ if (self.progress_fn is not None) and (self.num_items is not None):
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+ return ProgressMonitor(
+ tag = tag,
+ num_items = num_items,
+ flush_interval = flush_interval,
+ verbose = self.verbose,
+ progress_fn = self.progress_fn,
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+ pfn_total = self.pfn_total,
+ )
+
+#----------------------------------------------------------------------------
+
+@torch.no_grad()
+def compute_feature_stats_for_dataset(
+ opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64,
+ data_loader_kwargs=None, max_items=None, temporal_detector=False, use_image_dataset=False,
+ feature_stats_cls=FeatureStats, **stats_kwargs):
+
+ dataset_kwargs = video_to_image_dataset_kwargs(opts.dataset_kwargs) if use_image_dataset else opts.dataset_kwargs
+ dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
+
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ # Try to lookup from cache.
+ cache_file = None
+ if opts.cache:
+ # Choose cache file name.
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ stats_kwargs=stats_kwargs, feature_stats_cls=feature_stats_cls.__name__)
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+ # Check if the file exists (all processes must agree).
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+ if opts.num_gpus > 1:
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+ torch.distributed.broadcast(tensor=flag, src=0)
+ flag = (float(flag.cpu()) != 0)
+
+ # Load.
+ if flag:
+ return feature_stats_cls.load(cache_file)
+
+ # Initialize.
+ num_items = len(dataset)
+ data_length = len(dataset)
+ if max_items is not None:
+ num_items = min(num_items, max_items)
+ stats = feature_stats_cls(max_items=num_items, **stats_kwargs)
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ # item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] # original stylegan-v code
+ item_subset = random.sample(range(data_length), num_items) # added by xin, randomly selected 2048 videos
+ for batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
+ images = batch['image']
+ if temporal_detector:
+ images = images.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]
+
+ # images = images.float() / 255
+ # images = torch.nn.functional.interpolate(images, size=(images.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
+ # images = torch.nn.functional.interpolate(images, size=(images.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
+ # images = (images * 255).to(torch.uint8)
+ else:
+ images = images.view(-1, *images.shape[-3:]) # [-1, c, h, w]
+
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
+ features = detector(images.to(opts.device), **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+
+ # Save to cache.
+ if cache_file is not None and opts.rank == 0:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ temp_file = cache_file + '.' + uuid.uuid4().hex
+ stats.save(temp_file)
+ os.replace(temp_file, cache_file) # atomic
+ return stats
+
+#----------------------------------------------------------------------------
+
+@torch.no_grad()
+def compute_feature_stats_for_generator(
+ opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
+ batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
+ feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs):
+
+ if batch_gen is None:
+ batch_gen = min(batch_size, 4)
+ assert batch_size % batch_gen == 0
+
+ # Setup generator and load labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+
+ # Image generation func.
+ def run_generator(z, c, t):
+ img = G(z=z, c=c, t=t, **opts.G_kwargs)
+ bt, c, h, w = img.shape
+
+ if temporal_detector:
+ img = img.view(bt // num_video_frames, num_video_frames, c, h, w) # [batch_size, t, c, h, w]
+ img = img.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]
+
+ # img = torch.nn.functional.interpolate(img, size=(img.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
+ # img = torch.nn.functional.interpolate(img, size=(img.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
+
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ return img
+
+ # JIT.
+ if jit:
+ z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
+ c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
+ t = torch.zeros([batch_gen, G.cfg.sampling.num_frames_per_video], device=opts.device)
+ run_generator = torch.jit.trace(run_generator, [z, c, t], check_trace=False)
+
+ # Initialize.
+ stats = feature_stats_cls(**stats_kwargs)
+ assert stats.max_items is not None
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ while not stats.is_full():
+ images = []
+ for _i in range(batch_size // batch_gen):
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+ cond_sample_idx = [np.random.randint(len(dataset)) for _ in range(batch_gen)]
+ c = [dataset.get_label(i) for i in cond_sample_idx]
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ t = [list(range(0, num_video_frames * subsample_factor, subsample_factor)) for _i in range(batch_gen)]
+ t = torch.from_numpy(np.stack(t)).pin_memory().to(opts.device)
+ images.append(run_generator(z, c, t))
+ images = torch.cat(images)
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
+ features = detector(images, **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+ return stats
+
+#----------------------------------------------------------------------------
+
+def rewrite_opts_for_gen_dataset(opts):
+ """
+ Updates dataset arguments in the opts to enable the second dataset stats computation
+ """
+ new_opts = copy.deepcopy(opts)
+ new_opts.dataset_kwargs = new_opts.gen_dataset_kwargs
+ new_opts.cache = False
+
+ return new_opts
+
+#----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/tools/metrics/video_inception_score.py b/tools/metrics/video_inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb4309c67c3b772362f987fd283071d753109aab
--- /dev/null
+++ b/tools/metrics/video_inception_score.py
@@ -0,0 +1,54 @@
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+NUM_FRAMES_IN_BATCH = {128: 128, 256: 128, 512: 64, 1024: 32}
+
+#----------------------------------------------------------------------------
+
+def compute_isv(opts, num_gen: int, num_splits: int, backbone: str):
+ if backbone == 'c3d_ucf101':
+ # Perfectly reproduced torchscript version of the original chainer checkpoint:
+ # https://github.com/pfnet-research/tgan2/blob/f892bc432da315d4f6b6ae9448f69d046ef6fe01/tgan2/models/c3d/c3d_ucf101.py
+ # It is a UCF-101-finetuned C3D model.
+ detector_url = 'https://www.dropbox.com/s/jxpu7avzdc9n97q/c3d_ucf101.pt?dl=1'
+ else:
+ raise NotImplementedError(f'Backbone {backbone} is not supported.')
+
+ num_frames = 16
+ batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution] // num_frames
+
+ if opts.generator_as_dataset:
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
+ gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
+ gen_opts.dataset_kwargs.load_n_consecutive = num_frames
+ gen_opts.dataset_kwargs.load_n_consecutive_random_offset = False
+ gen_opts.dataset_kwargs.subsample_factor = 1
+ gen_kwargs = dict()
+ else:
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
+ gen_opts = opts
+ gen_kwargs = dict(num_video_frames=num_frames, subsample_factor=1)
+
+ gen_probs = compute_gen_stats_fn(
+ opts=gen_opts, detector_url=detector_url, detector_kwargs={},
+ capture_all=True, max_items=num_gen, temporal_detector=True, **gen_kwargs).get_all() # [num_gen, num_classes]
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
+ scores = []
+ np.random.RandomState(42).shuffle(gen_probs)
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/__init__.py b/tools/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9
--- /dev/null
+++ b/tools/torch_utils/__init__.py
@@ -0,0 +1,9 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/tools/torch_utils/custom_ops.py b/tools/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e
--- /dev/null
+++ b/tools/torch_utils/custom_ops.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import glob
+import torch
+import torch.utils.cpp_extension
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Compile and load.
+ verbose_build = (verbosity == 'full')
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+ # Compute a combined hash digest for all source files in the same
+ # custom op directory (usually .cu, .cpp, .py and .h files).
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+ if not os.path.isdir(digest_build_dir):
+ os.makedirs(digest_build_dir, exist_ok=True)
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+ if baton.try_acquire():
+ try:
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+ finally:
+ baton.release()
+ else:
+ # Someone else is copying source files under the digest dir,
+ # wait until done and continue.
+ baton.wait()
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/misc.py b/tools/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad57fa748b1ac4bfab6f0d07d31950766d95f44e
--- /dev/null
+++ b/tools/torch_utils/misc.py
@@ -0,0 +1,274 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+from tools import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to suppress known warnings in torch.jit.trace().
+
+class suppress_tracer_warnings(warnings.catch_warnings):
+ def __enter__(self):
+ super().__enter__()
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+ return self
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ err_suffix = f' for tensor of size {list(tensor.shape)}'
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}{err_suffix}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}{err_suffix}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}{err_suffix}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}{err_suffix}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (nan_to_num(tensor) == nan_to_num(other)).all(), f'{fullname} is not DDP consistent'
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, module_inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ module_inputs = list(module_inputs) if isinstance(module_inputs, (tuple, list)) else [module_inputs]
+ module_inputs = [t for t in module_inputs if isinstance(t, torch.Tensor)]
+ if isinstance(outputs, (tuple, list)):
+ outputs = list(outputs)
+ elif isinstance(outputs, dict):
+ outputs = list(outputs.values())
+ else:
+ outputs = [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, inputs=module_inputs, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Input Shape', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ input_shape_str = ' + '.join([str(list(t.shape)) for t in e.inputs])
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ input_shape_str if len(input_shape_str) > 0 else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-', '-']]
+ row_lengths = [len(r) for r in rows]
+ assert len(set(row_lengths)) == 1, f"Summary table contains rows of different lengths: {row_lengths}"
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/__init__.py b/tools/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9
--- /dev/null
+++ b/tools/torch_utils/ops/__init__.py
@@ -0,0 +1,9 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/tools/torch_utils/ops/bias_act.cpp b/tools/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330
--- /dev/null
+++ b/tools/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/bias_act.cu b/tools/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880
--- /dev/null
+++ b/tools/torch_utils/ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel;
+ if (p.act == 2) return (void*)bias_act_kernel;
+ if (p.act == 3) return (void*)bias_act_kernel;
+ if (p.act == 4) return (void*)bias_act_kernel;
+ if (p.act == 5) return (void*)bias_act_kernel;
+ if (p.act == 6) return (void*)bias_act_kernel;
+ if (p.act == 7) return (void*)bias_act_kernel;
+ if (p.act == 8) return (void*)bias_act_kernel;
+ if (p.act == 9) return (void*)bias_act_kernel;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/bias_act.h b/tools/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/tools/torch_utils/ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/bias_act.py b/tools/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..f64c4a7c5b22cb5240c5175a398a748d0d0e6db1
--- /dev/null
+++ b/tools/torch_utils/ops/bias_act.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import warnings
+import numpy as np
+import torch
+from tools import dnnlib
+import traceback
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ _inited = True
+ sources = ['bias_act.cpp', 'bias_act.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ except:
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/conv2d_gradfix.py b/tools/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ed80fbc9d4bd0d749f8ead245826c3bc6f5e4c
--- /dev/null
+++ b/tools/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import warnings
+import contextlib
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+@contextlib.contextmanager
+def no_weight_gradients():
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']):
+ return True
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
+ return False
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ if not transpose:
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+ else: # transpose
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ ctx.save_for_backward(input, weight)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
+ assert grad_input.shape == input.shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input):
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
+ assert grad_weight.shape == weight_shape
+ ctx.save_for_backward(grad_output, input)
+ return grad_weight
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output.shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input.shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/conv2d_resample.py b/tools/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd4750744c83354bab78704d4ef51ad1070fcc4a
--- /dev/null
+++ b/tools/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ w = w.flip([2, 3])
+
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
+ if out_channels <= 4 and groups == 1:
+ in_shape = x.shape
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
+ else:
+ x = x.to(memory_format=torch.contiguous_format)
+ w = w.to(memory_format=torch.contiguous_format)
+ x = conv2d_gradfix.conv2d(x, w, groups=groups)
+ return x.to(memory_format=torch.channels_last)
+
+ # Otherwise => execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/fma.py b/tools/torch_utils/ops/fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eeac58a626c49231e04122b93e321ada954c5d3
--- /dev/null
+++ b/tools/torch_utils/ops/fma.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
+ assert x.shape == shape
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/grid_sample_gradfix.py b/tools/torch_utils/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca6b3413ea72a734703c34382c023b84523601fd
--- /dev/null
+++ b/tools/torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import warnings
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ if not enabled:
+ return False
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
+ return True
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
+ return False
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/upfirdn2d.cpp b/tools/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a
--- /dev/null
+++ b/tools/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/upfirdn2d.cu b/tools/torch_utils/ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916
--- /dev/null
+++ b/tools/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,350 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last
+
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ }
+ return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/upfirdn2d.h b/tools/torch_utils/ops/upfirdn2d.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd
--- /dev/null
+++ b/tools/torch_utils/ops/upfirdn2d.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct upfirdn2d_kernel_params
+{
+ const void* x;
+ const float* f;
+ void* y;
+
+ int2 up;
+ int2 down;
+ int2 pad0;
+ int flip;
+ float gain;
+
+ int4 inSize; // [width, height, channel, batch]
+ int4 inStride;
+ int2 filterSize; // [width, height]
+ int2 filterStride;
+ int4 outSize; // [width, height, channel, batch]
+ int4 outStride;
+ int sizeMinor;
+ int sizeMajor;
+
+ int loopMinor;
+ int loopMajor;
+ int loopX;
+ int launchMinor;
+ int launchMajor;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct upfirdn2d_kernel_spec
+{
+ void* kernel;
+ int tileOutW;
+ int tileOutH;
+ int loopMinor;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/tools/torch_utils/ops/upfirdn2d.py b/tools/torch_utils/ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceeac2b9834e33b7c601c28bf27f32aa91c69256
--- /dev/null
+++ b/tools/torch_utils/ops/upfirdn2d.py
@@ -0,0 +1,384 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient resampling of 2D images."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import traceback
+
+from .. import custom_ops
+from .. import misc
+from . import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ except:
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+def _parse_scaling(scaling):
+ if isinstance(scaling, int):
+ scaling = [scaling, scaling]
+ assert isinstance(scaling, (list, tuple))
+ assert all(isinstance(x, int) for x in scaling)
+ sx, sy = scaling
+ assert sx >= 1 and sy >= 1
+ return sx, sy
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, int) for x in padding)
+ if len(padding) == 2:
+ padx, pady = padding
+ padding = [padx, padx, pady, pady]
+ padx0, padx1, pady0, pady1 = padding
+ return padx0, padx1, pady0, pady1
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ fw = f.shape[-1]
+ fh = f.shape[0]
+ with misc.suppress_tracer_warnings():
+ fw = int(fw)
+ fh = int(fh)
+ misc.assert_shape(f, [fh, fw][:f.ndim])
+ assert fw >= 1 and fh >= 1
+ return fw, fh
+
+#----------------------------------------------------------------------------
+
+def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
+
+ Args:
+ f: Torch tensor, numpy array, or python list of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable),
+ `[]` (impulse), or
+ `None` (identity).
+ device: Result device (default: cpu).
+ normalize: Normalize the filter so that it retains the magnitude
+ for constant input signal (DC)? (default: True).
+ flip_filter: Flip the filter? (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ separable: Return a separable filter? (default: select automatically).
+
+ Returns:
+ Float32 tensor of the shape
+ `[filter_height, filter_width]` (non-separable) or
+ `[filter_taps]` (separable).
+ """
+ # Validate.
+ if f is None:
+ f = 1
+ f = torch.as_tensor(f, dtype=torch.float32)
+ assert f.ndim in [0, 1, 2]
+ assert f.numel() > 0
+ if f.ndim == 0:
+ f = f[np.newaxis]
+
+ # Separable?
+ if separable is None:
+ separable = (f.ndim == 1 and f.numel() >= 8)
+ if f.ndim == 1 and not separable:
+ f = f.ger(f)
+ assert f.ndim == (1 if separable else 2)
+
+ # Apply normalize, flip, gain, and device.
+ if normalize:
+ f /= f.sum()
+ if flip_filter:
+ f = f.flip(list(range(f.ndim)))
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(device=device)
+ return f
+
+#----------------------------------------------------------------------------
+
+def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 4. Downsample the image by keeping every Nth pixel (`down`).
+
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ assert f.dtype == torch.float32 and not f.requires_grad
+ batch_size, num_channels, in_height, in_width = x.shape
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Upsample by inserting zeros.
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
+
+ # Pad or crop.
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
+
+ # Setup filter.
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(x.dtype)
+ if not flip_filter:
+ f = f.flip(list(range(f.ndim)))
+
+ # Convolve with the filter.
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
+ if f.ndim == 4:
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
+ else:
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
+
+ # Downsample by throwing away pixels.
+ x = x[:, :, ::downy, ::downx]
+ return x
+
+#----------------------------------------------------------------------------
+
+_upfirdn2d_cuda_cache = dict()
+
+def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
+ """
+ # Parse arguments.
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Lookup from cache.
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ if key in _upfirdn2d_cuda_cache:
+ return _upfirdn2d_cuda_cache[key]
+
+ # Forward op.
+ class Upfirdn2dCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ y = x
+ if f.ndim == 2:
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ else:
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
+ ctx.save_for_backward(f)
+ ctx.x_shape = x.shape
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ f, = ctx.saved_tensors
+ _, _, ih, iw = ctx.x_shape
+ _, _, oh, ow = dy.shape
+ fw, fh = _get_filter_size(f)
+ p = [
+ fw - padx0 - 1,
+ iw * upx - ow * downx + padx0 - upx + 1,
+ fh - pady0 - 1,
+ ih * upy - oh * downy + pady0 - upy + 1,
+ ]
+ dx = None
+ df = None
+
+ if ctx.needs_input_grad[0]:
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
+
+ assert not ctx.needs_input_grad[1]
+ return dx, df
+
+ # Add to cache.
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
+ return Upfirdn2dCuda
+
+#----------------------------------------------------------------------------
+
+def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape matches the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + fw // 2,
+ padx1 + (fw - 1) // 2,
+ pady0 + fh // 2,
+ pady1 + (fh - 1) // 2,
+ ]
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a multiple of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ upx, upy = _parse_scaling(up)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw + upx - 1) // 2,
+ padx1 + (fw - upx) // 2,
+ pady0 + (fh + upy - 1) // 2,
+ pady1 + (fh - upy) // 2,
+ ]
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a fraction of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the input. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw - downx + 1) // 2,
+ padx1 + (fw - downx) // 2,
+ pady0 + (fh - downy + 1) // 2,
+ pady1 + (fh - downy) // 2,
+ ]
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/persistence.py b/tools/torch_utils/persistence.py
new file mode 100644
index 0000000000000000000000000000000000000000..724bababdf9c11df905b3ab6c08f9409bedee1f1
--- /dev/null
+++ b/tools/torch_utils/persistence.py
@@ -0,0 +1,251 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Facilities for pickling Python code alongside other data.
+
+The pickled code is automatically imported into a separate Python module
+during unpickling. This way, any previously exported pickles will remain
+usable even if the original code is no longer available, or if the current
+version of the code is not consistent with what was originally pickled."""
+
+import sys
+import pickle
+import io
+import inspect
+import copy
+import uuid
+import types
+from tools import dnnlib
+
+#----------------------------------------------------------------------------
+
+_version = 6 # internal version number
+_decorators = set() # {decorator_class, ...}
+_import_hooks = [] # [hook_function, ...]
+_module_to_src_dict = dict() # {module: src, ...}
+_src_to_module_dict = dict() # {src: module, ...}
+
+#----------------------------------------------------------------------------
+
+def persistent_class(orig_class):
+ r"""Class decorator that extends a given class to save its source code
+ when pickled.
+
+ Example:
+
+ from src.torch_utils import persistence
+
+ @persistence.persistent_class
+ class MyNetwork(torch.nn.Module):
+ def __init__(self, num_inputs, num_outputs):
+ super().__init__()
+ self.fc = MyLayer(num_inputs, num_outputs)
+ ...
+
+ @persistence.persistent_class
+ class MyLayer(torch.nn.Module):
+ ...
+
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
+ source code alongside other internal state (e.g., parameters, buffers,
+ and submodules). This way, any previously exported pickle will remain
+ usable even if the class definitions have been modified or are no
+ longer available.
+
+ The decorator saves the source code of the entire Python module
+ containing the decorated class. It does *not* save the source code of
+ any imported modules. Thus, the imported modules must be available
+ during unpickling, also including `torch_utils.persistence` itself.
+
+ It is ok to call functions defined in the same module from the
+ decorated class. However, if the decorated class depends on other
+ classes defined in the same module, they must be decorated as well.
+ This is illustrated in the above example in the case of `MyLayer`.
+
+ It is also possible to employ the decorator just-in-time before
+ calling the constructor. For example:
+
+ cls = MyLayer
+ if want_to_make_it_persistent:
+ cls = persistence.persistent_class(cls)
+ layer = cls(num_inputs, num_outputs)
+
+ As an additional feature, the decorator also keeps track of the
+ arguments that were used to construct each instance of the decorated
+ class. The arguments can be queried via `obj.init_args` and
+ `obj.init_kwargs`, and they are automatically pickled alongside other
+ object state. A typical use case is to first unpickle a previous
+ instance of a persistent class, and then upgrade it to use the latest
+ version of the source code:
+
+ with open('old_pickle.pkl', 'rb') as f:
+ old_net = pickle.load(f)
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
+ """
+ assert isinstance(orig_class, type)
+ if is_persistent(orig_class):
+ return orig_class
+
+ assert orig_class.__module__ in sys.modules
+ orig_module = sys.modules[orig_class.__module__]
+ orig_module_src = _module_to_src(orig_module)
+
+ class Decorator(orig_class):
+ _orig_module_src = orig_module_src
+ _orig_class_name = orig_class.__name__
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._init_args = copy.deepcopy(args)
+ self._init_kwargs = copy.deepcopy(kwargs)
+ assert orig_class.__name__ in orig_module.__dict__
+ _check_pickleable(self.__reduce__())
+
+ @property
+ def init_args(self):
+ return copy.deepcopy(self._init_args)
+
+ @property
+ def init_kwargs(self):
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
+
+ def __reduce__(self):
+ fields = list(super().__reduce__())
+ fields += [None] * max(3 - len(fields), 0)
+ if fields[0] is not _reconstruct_persistent_obj:
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
+ fields[1] = (meta,) # reconstruct args
+ fields[2] = None # state dict
+ return tuple(fields)
+
+ Decorator.__name__ = orig_class.__name__
+ _decorators.add(Decorator)
+ return Decorator
+
+#----------------------------------------------------------------------------
+
+def is_persistent(obj):
+ r"""Test whether the given object or class is persistent, i.e.,
+ whether it will save its source code when pickled.
+ """
+ try:
+ if obj in _decorators:
+ return True
+ except TypeError:
+ pass
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
+#----------------------------------------------------------------------------
+
+def import_hook(hook):
+ r"""Register an import hook that is called whenever a persistent object
+ is being unpickled. A typical use case is to patch the pickled source
+ code to avoid errors and inconsistencies when the API of some imported
+ module has changed.
+
+ The hook should have the following signature:
+
+ hook(meta) -> modified meta
+
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
+
+ type: Type of the persistent object, e.g. `'class'`.
+ version: Internal version number of `torch_utils.persistence`.
+ module_src Original source code of the Python module.
+ class_name: Class name in the original Python module.
+ state: Internal state of the object.
+
+ Example:
+
+ @persistence.import_hook
+ def wreck_my_network(meta):
+ if meta.class_name == 'MyNetwork':
+ print('MyNetwork is being imported. I will wreck it!')
+ meta.module_src = meta.module_src.replace("True", "False")
+ return meta
+ """
+ assert callable(hook)
+ _import_hooks.append(hook)
+
+#----------------------------------------------------------------------------
+
+def _reconstruct_persistent_obj(meta):
+ r"""Hook that is called internally by the `pickle` module to unpickle
+ a persistent object.
+ """
+ meta = dnnlib.EasyDict(meta)
+ meta.state = dnnlib.EasyDict(meta.state)
+ for hook in _import_hooks:
+ meta = hook(meta)
+ assert meta is not None
+
+ assert meta.version == _version
+ module = _src_to_module(meta.module_src)
+
+ assert meta.type == 'class'
+ orig_class = module.__dict__[meta.class_name]
+ decorator_class = persistent_class(orig_class)
+ obj = decorator_class.__new__(decorator_class)
+
+ setstate = getattr(obj, '__setstate__', None)
+ if callable(setstate):
+ setstate(meta.state) # pylint: disable=not-callable
+ else:
+ obj.__dict__.update(meta.state)
+ return obj
+
+#----------------------------------------------------------------------------
+
+def _module_to_src(module):
+ r"""Query the source code of a given Python module.
+ """
+ src = _module_to_src_dict.get(module, None)
+ if src is None:
+ src = inspect.getsource(module)
+ _module_to_src_dict[module] = src
+ _src_to_module_dict[src] = module
+ return src
+
+def _src_to_module(src):
+ r"""Get or create a Python module for the given source code.
+ """
+ module = _src_to_module_dict.get(src, None)
+ if module is None:
+ module_name = "_imported_module_" + uuid.uuid4().hex
+ module = types.ModuleType(module_name)
+ sys.modules[module_name] = module
+ _module_to_src_dict[module] = src
+ _src_to_module_dict[src] = module
+ exec(src, module.__dict__) # pylint: disable=exec-used
+ return module
+
+#----------------------------------------------------------------------------
+
+def _check_pickleable(obj):
+ r"""Check that the given object is pickleable, raising an exception if
+ it is not. This function is expected to be considerably more efficient
+ than actually pickling the object.
+ """
+ def recurse(obj):
+ if isinstance(obj, (list, tuple, set)):
+ return [recurse(x) for x in obj]
+ if isinstance(obj, dict):
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
+ return None # Python primitive types are pickleable.
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
+ return None # NumPy arrays and PyTorch tensors are pickleable.
+ if is_persistent(obj):
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
+ return obj
+ with io.BytesIO() as f:
+ pickle.dump(recurse(obj), f)
+
+#----------------------------------------------------------------------------
diff --git a/tools/torch_utils/training_stats.py b/tools/torch_utils/training_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..57432c7935b8e43d2e7fa02b63c515681ee11245
--- /dev/null
+++ b/tools/torch_utils/training_stats.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Facilities for reporting and collecting training statistics across
+multiple processes and devices. The interface is designed to minimize
+synchronization overhead as well as the amount of boilerplate in user
+code."""
+
+import re
+import numpy as np
+import torch
+from tools import dnnlib
+
+from . import misc
+
+#----------------------------------------------------------------------------
+
+_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
+_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
+_counter_dtype = torch.float64 # Data type to use for the internal counters.
+_rank = 0 # Rank of the current process.
+_sync_device = None # Device to use for multiprocess communication. None = single-process.
+_sync_called = False # Has _sync() been called yet?
+_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
+_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+
+#----------------------------------------------------------------------------
+
+def init_multiprocessing(rank, sync_device):
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
+ across multiple processes.
+
+ This function must be called after
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
+ The call is not necessary if multi-process collection is not needed.
+
+ Args:
+ rank: Rank of the current process.
+ sync_device: PyTorch device to use for inter-process
+ communication, or None to disable multi-process
+ collection. Typically `torch.device('cuda', rank)`.
+ """
+ global _rank, _sync_device
+ assert not _sync_called
+ _rank = rank
+ _sync_device = sync_device
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def report(name, value):
+ r"""Broadcasts the given set of scalars to all interested instances of
+ `Collector`, across device and process boundaries.
+
+ This function is expected to be extremely cheap and can be safely
+ called from anywhere in the training loop, loss function, or inside a
+ `torch.nn.Module`.
+
+ Warning: The current implementation expects the set of unique names to
+ be consistent across processes. Please make sure that `report()` is
+ called at least once for each unique name by each process, and in the
+ same order. If a given process has no scalars to broadcast, it can do
+ `report(name, [])` (empty list).
+
+ Args:
+ name: Arbitrary string specifying the name of the statistic.
+ Averages are accumulated separately for each unique name.
+ value: Arbitrary set of scalars. Can be a list, tuple,
+ NumPy array, PyTorch tensor, or Python scalar.
+
+ Returns:
+ The same `value` that was passed in.
+ """
+ if name not in _counters:
+ _counters[name] = dict()
+
+ elems = torch.as_tensor(value)
+ if elems.numel() == 0:
+ return value
+
+ elems = elems.detach().flatten().to(_reduce_dtype)
+ moments = torch.stack([
+ torch.ones_like(elems).sum(),
+ elems.sum(),
+ elems.square().sum(),
+ ])
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
+ moments = moments.to(_counter_dtype)
+
+ device = moments.device
+ if device not in _counters[name]:
+ _counters[name][device] = torch.zeros_like(moments)
+ _counters[name][device].add_(moments)
+ return value
+
+#----------------------------------------------------------------------------
+
+def report0(name, value):
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
+ but ignores any scalars provided by the other processes.
+ See `report()` for further details.
+ """
+ report(name, value if _rank == 0 else [])
+ return value
+
+#----------------------------------------------------------------------------
+
+class Collector:
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
+ computes their long-term averages (mean and standard deviation) over
+ user-defined periods of time.
+
+ The averages are first collected into internal counters that are not
+ directly visible to the user. They are then copied to the user-visible
+ state as a result of calling `update()` and can then be queried using
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
+ internal counters for the next round, so that the user-visible state
+ effectively reflects averages collected between the last two calls to
+ `update()`.
+
+ Args:
+ regex: Regular expression defining which statistics to
+ collect. The default is to collect everything.
+ keep_previous: Whether to retain the previous averages if no
+ scalars were collected on a given round
+ (default: True).
+ """
+ def __init__(self, regex='.*', keep_previous=True):
+ self._regex = re.compile(regex)
+ self._keep_previous = keep_previous
+ self._cumulative = dict()
+ self._moments = dict()
+ self.update()
+ self._moments.clear()
+
+ def names(self):
+ r"""Returns the names of all statistics broadcasted so far that
+ match the regular expression specified at construction time.
+ """
+ return [name for name in _counters if self._regex.fullmatch(name)]
+
+ def update(self):
+ r"""Copies current values of the internal counters to the
+ user-visible state and resets them for the next round.
+
+ If `keep_previous=True` was specified at construction time, the
+ operation is skipped for statistics that have received no scalars
+ since the last update, retaining their previous averages.
+
+ This method performs a number of GPU-to-CPU transfers and one
+ `torch.distributed.all_reduce()`. It is intended to be called
+ periodically in the main training loop, typically once every
+ N training steps.
+ """
+ if not self._keep_previous:
+ self._moments.clear()
+ for name, cumulative in _sync(self.names()):
+ if name not in self._cumulative:
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ delta = cumulative - self._cumulative[name]
+ self._cumulative[name].copy_(cumulative)
+ if float(delta[0]) != 0:
+ self._moments[name] = delta
+
+ def _get_delta(self, name):
+ r"""Returns the raw moments that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ assert self._regex.fullmatch(name)
+ if name not in self._moments:
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ return self._moments[name]
+
+ def num(self, name):
+ r"""Returns the number of scalars that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ return int(delta[0])
+
+ def mean(self, name):
+ r"""Returns the mean of the scalars that were accumulated for the
+ given statistic between the last two calls to `update()`, or NaN if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0:
+ return float('nan')
+ return float(delta[1] / delta[0])
+
+ def std(self, name):
+ r"""Returns the standard deviation of the scalars that were
+ accumulated for the given statistic between the last two calls to
+ `update()`, or NaN if no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
+ return float('nan')
+ if int(delta[0]) == 1:
+ return float(0)
+ mean = float(delta[1] / delta[0])
+ raw_var = float(delta[2] / delta[0])
+ return np.sqrt(max(raw_var - np.square(mean), 0))
+
+ def as_dict(self):
+ r"""Returns the averages accumulated between the last two calls to
+ `update()` as an `dnnlib.EasyDict`. The contents are as follows:
+
+ dnnlib.EasyDict(
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
+ ...
+ )
+ """
+ stats = dnnlib.EasyDict()
+ for name in self.names():
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
+ return stats
+
+ def __getitem__(self, name):
+ r"""Convenience getter.
+ `collector[name]` is a synonym for `collector.mean(name)`.
+ """
+ return self.mean(name)
+
+#----------------------------------------------------------------------------
+
+def _sync(names):
+ r"""Synchronize the global cumulative counters across devices and
+ processes. Called internally by `Collector.update()`.
+ """
+ if len(names) == 0:
+ return []
+ global _sync_called
+ _sync_called = True
+
+ # Collect deltas within current rank.
+ deltas = []
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
+ for name in names:
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
+ for counter in _counters[name].values():
+ delta.add_(counter.to(device))
+ counter.copy_(torch.zeros_like(counter))
+ deltas.append(delta)
+ deltas = torch.stack(deltas)
+
+ # Sum deltas across ranks.
+ if _sync_device is not None:
+ torch.distributed.all_reduce(deltas)
+
+ # Update cumulative values.
+ deltas = deltas.cpu()
+ for idx, name in enumerate(names):
+ if name not in _cumulative:
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ _cumulative[name].add_(deltas[idx])
+
+ # Return name-value pairs.
+ return [(name, _cumulative[name]) for name in names]
+
+#----------------------------------------------------------------------------
diff --git a/tools/utils/__init__.py b/tools/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55435dbaba8542b9080b8bdcbc8ed2015d445a4b
--- /dev/null
+++ b/tools/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
diff --git a/tools/utils/dataset.py b/tools/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c719caaa97a42ba5b1ddef031d6532e8a9e5eb6e
--- /dev/null
+++ b/tools/utils/dataset.py
@@ -0,0 +1,497 @@
+๏ปฟ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import copy
+from typing import List, Dict
+import zipfile
+import json
+import random
+from typing import Tuple
+
+import numpy as np
+import PIL.Image
+import torch
+from tools import dnnlib
+from omegaconf import DictConfig, OmegaConf
+
+from tools.utils.layers import sample_frames
+
+try:
+ import pyspng
+except ImportError:
+ pyspng = None
+
+#----------------------------------------------------------------------------
+
+NUMPY_INTEGER_TYPES = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64]
+NUMPY_FLOAT_TYPES = [np.float16, np.float32, np.float64, np.single, np.double]
+
+#----------------------------------------------------------------------------
+
+class Dataset(torch.utils.data.Dataset):
+ def __init__(self,
+ name, # Name of the dataset.
+ raw_shape, # Shape of the raw image data (NCHW).
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
+ random_seed = 0, # Random seed to use when applying max_size.
+ ):
+ self._name = name
+ self._raw_shape = list(raw_shape)
+ self._use_labels = use_labels
+ self._raw_labels = None
+ self._label_shape = None
+
+ # Apply max_size.
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
+ if (max_size is not None) and (self._raw_idx.size > max_size):
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
+
+ # Apply xflip.
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
+ if xflip:
+ self._raw_idx = np.tile(self._raw_idx, 2)
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
+
+ @staticmethod
+ def _file_ext(fname):
+ return os.path.splitext(fname)[1].lower()
+
+ def _get_raw_labels(self):
+ if self._raw_labels is None:
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
+ if self._raw_labels is None:
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
+ assert isinstance(self._raw_labels, np.ndarray)
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
+ assert self._raw_labels.dtype in [np.float32, np.int64]
+ if self._raw_labels.dtype == np.int64:
+ assert np.all(self._raw_labels >= 0)
+ return self._raw_labels
+
+ def close(self): # to be overridden by subclass
+ pass
+
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
+ raise NotImplementedError
+
+ def _load_raw_labels(self): # to be overridden by subclass
+ raise NotImplementedError
+
+ def __getstate__(self):
+ return dict(self.__dict__, _raw_labels=None)
+
+ def __del__(self):
+ try:
+ self.close()
+ except:
+ pass
+
+ def __len__(self):
+ return self._raw_idx.size
+
+ def __getitem__(self, idx):
+ image = self._load_raw_image(self._raw_idx[idx])
+ assert isinstance(image, np.ndarray)
+ assert list(image.shape) == self.image_shape
+ assert image.dtype == np.uint8
+ if self._xflip[idx]:
+ assert image.ndim == 3 # CHW
+ image = image[:, :, ::-1]
+
+ return {
+ 'image': image.copy(),
+ 'label': self.get_label(idx),
+ }
+
+ def get_label(self, idx):
+ label = self._get_raw_labels()[self._raw_idx[idx]]
+ if label.dtype == np.int64:
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
+ onehot[label] = 1
+ label = onehot
+ return label.copy()
+
+ def get_details(self, idx):
+ d = dnnlib.EasyDict()
+ d.raw_idx = int(self._raw_idx[idx])
+ d.xflip = (int(self._xflip[idx]) != 0)
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
+ return d
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def image_shape(self):
+ return list(self._raw_shape[1:])
+
+ @property
+ def num_channels(self):
+ assert len(self.image_shape) == 3 # CHW
+ return self.image_shape[0]
+
+ @property
+ def resolution(self):
+ assert len(self.image_shape) == 3 # CHW
+ assert self.image_shape[1] == self.image_shape[2]
+ return self.image_shape[1]
+
+ @property
+ def label_shape(self):
+ if self._label_shape is None:
+ raw_labels = self._get_raw_labels()
+ if raw_labels.dtype == np.int64:
+ self._label_shape = [int(np.max(raw_labels)) + 1]
+ else:
+ self._label_shape = raw_labels.shape[1:]
+ return list(self._label_shape)
+
+ @property
+ def label_dim(self):
+ assert len(self.label_shape) == 1, f"Labels must be 1-dimensional: {self.label_shape} to use `.label_dim`"
+ return self.label_shape[0]
+
+ @property
+ def has_labels(self):
+ return any(x != 0 for x in self.label_shape)
+
+ @property
+ def has_onehot_labels(self):
+ return self._get_raw_labels().dtype == np.int64
+
+#----------------------------------------------------------------------------
+
+class ImageFolderDataset(Dataset):
+ def __init__(self,
+ path, # Path to directory or zip.
+ resolution = None, # Ensure specific resolution, None = highest available.
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self._path = path
+ self._zipfile = None
+
+ if os.path.isdir(self._path):
+ self._type = 'dir'
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
+ elif self._file_ext(self._path) == '.zip':
+ self._type = 'zip'
+ self._all_fnames = set(self._get_zipfile().namelist())
+ else:
+ raise IOError('Path must point to a directory or zip')
+
+ PIL.Image.init()
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
+ if len(self._image_fnames) == 0:
+ raise IOError('No image files found in the specified path')
+
+ name = os.path.splitext(os.path.basename(self._path))[0]
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
+ raise IOError(f'Image files do not match the specified resolution. Resolution is {resolution}, shape is {raw_shape}')
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
+
+ def _get_zipfile(self):
+ assert self._type == 'zip'
+ if self._zipfile is None:
+ self._zipfile = zipfile.ZipFile(self._path)
+ return self._zipfile
+
+ def _open_file(self, fname):
+ if self._type == 'dir':
+ return open(os.path.join(self._path, fname), 'rb')
+ if self._type == 'zip':
+ return self._get_zipfile().open(fname, 'r')
+ return None
+
+ def close(self):
+ try:
+ if self._zipfile is not None:
+ self._zipfile.close()
+ finally:
+ self._zipfile = None
+
+ def __getstate__(self):
+ return dict(super().__getstate__(), _zipfile=None)
+
+ def _load_raw_image(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+
+ with self._open_file(fname) as f:
+ use_pyspng = pyspng is not None and self._file_ext(fname) == '.png'
+ image = load_image_from_buffer(f, use_pyspng=use_pyspng)
+
+ return image
+
+ def _load_raw_labels(self):
+ fname = 'dataset.json'
+ labels_files = [f for f in self._all_fnames if f.endswith(fname)]
+ if len(labels_files) == 0:
+ return None
+ assert len(labels_files) == 1, f"There can be only a single {fname} file"
+ with self._open_file(labels_files[0]) as f:
+ labels = json.load(f)['labels']
+ if labels is None:
+ return None
+ labels = dict(labels)
+ labels = [labels[remove_root(fname, self._name).replace('\\', '/')] for fname in self._image_fnames]
+ labels = np.array(labels)
+
+ if labels.dtype in NUMPY_INTEGER_TYPES:
+ labels = labels.astype(np.int64)
+ elif labels.dtype in NUMPY_FLOAT_TYPES:
+ labels = labels.astype(np.float32)
+ else:
+ raise NotImplementedError(f"Unsupported label dtype: {labels.dtype}")
+
+ return labels
+
+#----------------------------------------------------------------------------
+
+class VideoFramesFolderDataset(Dataset):
+ def __init__(self,
+ path, # Path to directory or zip.
+ cfg: DictConfig, # Config
+ resolution=None, # Unused arg for backward compatibility
+ load_n_consecutive: int=None, # Should we load first N frames for each video?
+ load_n_consecutive_random_offset: bool=True, # Should we use a random offset when loading consecutive frames?
+ subsample_factor: int=1, # Sampling factor, i.e. decreasing the temporal resolution
+ discard_short_videos: bool=False, # Should we discard videos that are shorter than `load_n_consecutive`?
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self.sampling_dict = OmegaConf.to_container(OmegaConf.create({**cfg.sampling})) if 'sampling' in cfg else None
+ self.max_num_frames = cfg.max_num_frames
+ self._path = path
+ self._zipfile = None
+ self.load_n_consecutive = load_n_consecutive
+ self.load_n_consecutive_random_offset = load_n_consecutive_random_offset
+ self.subsample_factor = subsample_factor
+ print(subsample_factor)
+ self.discard_short_videos = discard_short_videos
+
+ if self.subsample_factor > 1 and self.load_n_consecutive is None:
+ raise NotImplementedError("Can do subsampling only when loading consecutive frames.")
+
+ listdir_full_paths = lambda d: sorted([os.path.join(d, x) for x in os.listdir(d)])
+ name = os.path.splitext(os.path.basename(self._path))[0]
+
+ if os.path.isdir(self._path):
+ self._type = 'dir'
+ # We assume that the depth is 2
+ self._all_objects = {o for d in listdir_full_paths(self._path) for o in (([d] + listdir_full_paths(d)) if os.path.isdir(d) else [d])}
+ self._all_objects = {os.path.relpath(o, start=os.path.dirname(self._path)) for o in {self._path}.union(self._all_objects)}
+ elif self._file_ext(self._path) == '.zip':
+ self._type = 'zip'
+ self._all_objects = set(self._get_zipfile().namelist())
+ else:
+ raise IOError('Path must be either a directory or point to a zip archive')
+
+ PIL.Image.init()
+ self._video_dir2frames = {}
+ objects = sorted([d for d in self._all_objects])
+ root_path_depth = len(os.path.normpath(objects[0]).split(os.path.sep))
+ curr_d = objects[1] # Root path is the first element
+
+ for o in objects[1:]:
+ curr_obj_depth = len(os.path.normpath(o).split(os.path.sep))
+
+ if self._file_ext(o) in PIL.Image.EXTENSION:
+ assert o.startswith(curr_d), f"Object {o} is out of sync. It should lie inside {curr_d}"
+ assert curr_obj_depth == root_path_depth + 2, "Frame images should be inside directories"
+ if not curr_d in self._video_dir2frames:
+ self._video_dir2frames[curr_d] = []
+ self._video_dir2frames[curr_d].append(o)
+ elif self._file_ext(o) == 'json':
+ assert curr_obj_depth == root_path_depth + 1, "Classes info file should be inside the root dir"
+ pass
+ else:
+ # We encountered a new directory
+ assert curr_obj_depth == root_path_depth + 1, f"Video directories should be inside the root dir. {o} is not."
+ if curr_d in self._video_dir2frames:
+ sorted_files = sorted(self._video_dir2frames[curr_d])
+ self._video_dir2frames[curr_d] = sorted_files
+ curr_d = o
+
+ if self.discard_short_videos:
+ self._video_dir2frames = {d: fs for d, fs in self._video_dir2frames.items() if len(fs) >= self.load_n_consecutive * self.subsample_factor}
+
+ self._video_idx2frames = [frames for frames in self._video_dir2frames.values()]
+
+ if len(self._video_idx2frames) == 0:
+ raise IOError('No videos found in the specified archive')
+
+ raw_shape = [len(self._video_idx2frames)] + list(self._load_raw_frames(0, [0])[0][0].shape)
+
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
+
+ def _get_zipfile(self):
+ assert self._type == 'zip'
+ if self._zipfile is None:
+ self._zipfile = zipfile.ZipFile(self._path)
+ return self._zipfile
+
+ def _open_file(self, fname):
+ if self._type == 'dir':
+ return open(os.path.join(os.path.dirname(self._path), fname), 'rb')
+ if self._type == 'zip':
+ return self._get_zipfile().open(fname, 'r')
+ return None
+
+ def close(self):
+ try:
+ if self._zipfile is not None:
+ self._zipfile.close()
+ finally:
+ self._zipfile = None
+
+ def __getstate__(self):
+ return dict(super().__getstate__(), _zipfile=None)
+
+ def _load_raw_labels(self):
+ """
+ We leave the `dataset.json` file in the same format as in the original SG2-ADA repo:
+ it's `labels` field is a hashmap of filename-label pairs.
+ """
+ fname = 'dataset.json'
+ labels_files = [f for f in self._all_objects if f.endswith(fname)]
+ if len(labels_files) == 0:
+ return None
+ assert len(labels_files) == 1, f"There can be only a single {fname} file"
+ with self._open_file(labels_files[0]) as f:
+ labels = json.load(f)['labels']
+ if labels is None:
+ return None
+
+ labels = dict(labels)
+ # The `dataset.json` file defines a label for each image and
+ # For the video dataset, this is both inconvenient and redundant.
+ # So let's redefine this
+ video_labels = {}
+ for filename, label in labels.items():
+ dirname = os.path.dirname(filename)
+ if dirname in video_labels:
+ assert video_labels[dirname] == label
+ else:
+ video_labels[dirname] = label
+ labels = video_labels
+ labels = [labels[os.path.normpath(dname).split(os.path.sep)[-1]] for dname in self._video_dir2frames]
+ labels = np.array(labels)
+
+ if labels.dtype in NUMPY_INTEGER_TYPES:
+ labels = labels.astype(np.int64)
+ elif labels.dtype in NUMPY_FLOAT_TYPES:
+ labels = labels.astype(np.float32)
+ else:
+ raise NotImplementedError(f"Unsupported label dtype: {labels.dtype}")
+
+ return labels
+
+ def __getitem__(self, idx: int) -> Dict:
+ if self.load_n_consecutive:
+ num_frames_available = len(self._video_idx2frames[self._raw_idx[idx]])
+ assert num_frames_available - self.load_n_consecutive * self.subsample_factor >= 0, f"We have only {num_frames_available} frames available, cannot load {self.load_n_consecutive} frames."
+
+ if self.load_n_consecutive_random_offset:
+ random_offset = random.randint(0, num_frames_available - self.load_n_consecutive * self.subsample_factor + self.subsample_factor - 1)
+ else:
+ random_offset = 0
+ frames_idx = np.arange(0, self.load_n_consecutive * self.subsample_factor, self.subsample_factor) + random_offset
+ else:
+ frames_idx = None
+
+ frames, times = self._load_raw_frames(self._raw_idx[idx], frames_idx=frames_idx)
+
+ assert isinstance(frames, np.ndarray)
+ assert list(frames[0].shape) == self.image_shape
+ assert frames.dtype == np.uint8
+ assert len(frames) == len(times)
+
+ if self._xflip[idx]:
+ assert frames.ndim == 4 # TCHW
+ frames = frames[:, :, :, ::-1]
+
+ return {
+ 'image': frames.copy(),
+ 'label': self.get_label(idx),
+ 'times': times,
+ 'video_len': self.get_video_len(idx),
+ }
+
+ def get_video_len(self, idx: int) -> int:
+ return min(self.max_num_frames, len(self._video_idx2frames[self._raw_idx[idx]]))
+
+ def _load_raw_frames(self, raw_idx: int, frames_idx: List[int]=None) -> Tuple[np.ndarray, np.ndarray]:
+ frame_paths = self._video_idx2frames[raw_idx]
+ total_len = len(frame_paths)
+ offset = 0
+ images = []
+
+ if frames_idx is None:
+ assert not self.sampling_dict is None, f"The dataset was created without `cfg.sampling` config and cannot sample frames on its own."
+ if total_len > self.max_num_frames:
+ offset = random.randint(0, total_len - self.max_num_frames)
+ frames_idx = sample_frames(self.sampling_dict, total_video_len=min(total_len, self.max_num_frames)) + offset
+ else:
+ frames_idx = np.array(frames_idx)
+
+ for frame_idx in frames_idx:
+ with self._open_file(frame_paths[frame_idx]) as f:
+ images.append(load_image_from_buffer(f))
+
+ return np.array(images), frames_idx - offset
+
+ def compute_max_num_frames(self) -> int:
+ return max(len(frames) for frames in self._video_idx2frames)
+
+#----------------------------------------------------------------------------
+
+def load_image_from_buffer(f, use_pyspng: bool=False) -> np.ndarray:
+ if use_pyspng:
+ image = pyspng.load(f.read())
+ else:
+ image = np.array(PIL.Image.open(f))
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis] # HW => HWC
+ image = image.transpose(2, 0, 1) # HWC => CHW
+
+ return image
+
+#----------------------------------------------------------------------------
+
+def video_to_image_dataset_kwargs(video_dataset_kwargs: dnnlib.EasyDict) -> dnnlib.EasyDict:
+ """Converts video dataset kwargs to image dataset kwargs"""
+ return dnnlib.EasyDict(
+ class_name='training.dataset.ImageFolderDataset',
+ path=video_dataset_kwargs.path,
+ use_labels=video_dataset_kwargs.use_labels,
+ xflip=video_dataset_kwargs.xflip,
+ resolution=video_dataset_kwargs.resolution,
+ random_seed=video_dataset_kwargs.get('random_seed'),
+ # Explicitly ignoring the max size, since we are now interested
+ # in the number of images instead of the number of videos
+ # max_size=video_dataset_kwargs.max_size,
+ )
+
+#----------------------------------------------------------------------------
+
+def remove_root(fname: os.PathLike, root_name: os.PathLike):
+ """`root_name` should NOT start with '/'"""
+ if fname == root_name or fname == ('/' + root_name):
+ return ''
+ elif fname.startswith(root_name + '/'):
+ return fname[len(root_name) + 1:]
+ elif fname.startswith('/' + root_name + '/'):
+ return fname[len(root_name) + 2:]
+ else:
+ return fname
+
+#----------------------------------------------------------------------------
diff --git a/tools/utils/layers.py b/tools/utils/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf82f6f2302148a640e2a26c3ed0d99d7bfb091
--- /dev/null
+++ b/tools/utils/layers.py
@@ -0,0 +1,448 @@
+import random
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from omegaconf import DictConfig
+
+from tools.torch_utils import persistence
+from tools.torch_utils.ops import bias_act, upfirdn2d, conv2d_resample
+from tools.torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def normalize_2nd_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
+ num_layers = 8, # Number of mapping layers.
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
+ cfg = {}, # Additional config
+ ):
+ super().__init__()
+
+ self.cfg = cfg
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ if embed_features is None:
+ embed_features = w_dim
+ if c_dim == 0:
+ embed_features = 0
+ if layer_features is None:
+ layer_features = w_dim
+
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
+
+ if c_dim > 0:
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
+
+ for idx in range(num_layers):
+ in_features = features_list[idx]
+ out_features = features_list[idx + 1]
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+
+ if num_ws is not None and w_avg_beta is not None:
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
+ # Embed, normalize, and concat inputs.
+ x = None
+ with torch.autograd.profiler.record_function('input'):
+ if self.z_dim > 0:
+ misc.assert_shape(z, [None, self.z_dim])
+ x = normalize_2nd_moment(z.to(torch.float32))
+
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Main layers.
+ for idx in range(self.num_layers):
+ layer = getattr(self, f'fc{idx}')
+ x = layer(x)
+
+ # Update moving average of W.
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
+ with torch.autograd.profiler.record_function('update_w_avg'):
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast.
+ if self.num_ws is not None:
+ with torch.autograd.profiler.record_function('broadcast'):
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+ # Apply truncation.
+ if truncation_psi != 1:
+ with torch.autograd.profiler.record_function('truncate'):
+ assert self.w_avg_beta is not None
+ if self.num_ws is None or truncation_cutoff is None:
+ x = self.w_avg.lerp(x, truncation_psi)
+ else:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], float(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class Conv2dLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size, # Width and height of the convolution kernel.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
+ channels_last = False, # Expect the input to have memory_format=channels_last?
+ trainable = True, # Update the weights of this layer during training?
+ instance_norm = False, # Should we apply instance normalization to y?
+ lr_multiplier = 1.0, # Learning rate multiplier.
+ ):
+ super().__init__()
+ self.activation = activation
+ self.up = up
+ self.down = down
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+ self.instance_norm = instance_norm
+ self.lr_multiplier = lr_multiplier
+
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+ bias = torch.zeros([out_channels]) if bias else None
+ if trainable:
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
+ else:
+ self.register_buffer('weight', weight)
+ if bias is not None:
+ self.register_buffer('bias', bias)
+ else:
+ self.bias = None
+
+ def forward(self, x, gain=1):
+ w = self.weight * (self.weight_gain * self.lr_multiplier)
+ b = (self.bias.to(x.dtype) * self.lr_multiplier) if self.bias is not None else None
+ flip_weight = (self.up == 1) # slightly faster
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
+
+ if self.instance_norm:
+ x = (x - x.mean(dim=(2,3), keepdim=True)) / (x.std(dim=(2,3), keepdim=True) + 1e-8) # [batch_size, c, h, w]
+
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class GenInput(nn.Module):
+ def __init__(self, cfg: DictConfig, channel_dim: int, motion_v_dim: int=None):
+ super().__init__()
+
+ self.cfg = cfg
+
+ if self.cfg.input.type == 'const':
+ self.input = torch.nn.Parameter(torch.randn([channel_dim, 4, 4]))
+ self.total_dim = channel_dim
+ elif self.cfg.input.type == 'temporal':
+ self.input = TemporalInput(self.cfg, channel_dim, motion_v_dim=motion_v_dim)
+ self.total_dim = self.input.get_dim()
+ else:
+ raise NotImplementedError(f'Unkown input type: {self.cfg.input.type}')
+
+ def forward(self, batch_size: int, motion_v: Optional[torch.Tensor]=None, dtype=None, memory_format=None) -> torch.Tensor:
+ if self.cfg.input.type == 'const':
+ x = self.input.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([batch_size, 1, 1, 1])
+ elif self.cfg.input.type == 'temporal':
+ x = self.input(motion_v=motion_v) # [batch_size, d, h, w]
+ else:
+ raise NotImplementedError(f'Unkown input type: {self.cfg.input.type}')
+
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class TemporalInput(nn.Module):
+ def __init__(self, cfg: DictConfig, channel_dim: int, motion_v_dim: int):
+ super().__init__()
+
+ self.cfg = cfg
+ self.motion_v_dim = motion_v_dim
+ self.const = nn.Parameter(torch.randn(1, channel_dim, 4, 4))
+
+ def get_dim(self):
+ return self.motion_v_dim + self.const.shape[1]
+
+ def forward(self, motion_v: torch.Tensor) -> torch.Tensor:
+ """
+ motion_v: [batch_size, motion_v_dim]
+ """
+ out = torch.cat([
+ self.const.repeat(len(motion_v), 1, 1, 1),
+ motion_v.unsqueeze(2).unsqueeze(3).repeat(1, 1, *self.const.shape[2:]),
+ ], dim=1) # [batch_size, channel_dim + num_fourier_feats * 2]
+
+ return out
+
+#----------------------------------------------------------------------------
+
+class TemporalDifferenceEncoder(nn.Module):
+ def __init__(self, cfg: DictConfig):
+ super().__init__()
+
+ self.cfg = cfg
+
+ if self.cfg.sampling.num_frames_per_video > 1:
+ self.d = 256
+ self.const_embed = nn.Embedding(self.cfg.sampling.max_num_frames, self.d)
+ self.time_encoder = FixedTimeEncoder(
+ self.cfg.sampling.max_num_frames,
+ skip_small_t_freqs=self.cfg.get('skip_small_t_freqs', 0))
+
+ def get_dim(self) -> int:
+ if self.cfg.sampling.num_frames_per_video == 1:
+ return 1
+ else:
+ if self.cfg.sampling.type == 'uniform':
+ return self.d + self.time_encoder.get_dim()
+ else:
+ return (self.d + self.time_encoder.get_dim()) * (self.cfg.sampling.num_frames_per_video - 1)
+
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
+ misc.assert_shape(t, [None, self.cfg.sampling.num_frames_per_video])
+
+ batch_size = t.shape[0]
+
+ if self.cfg.sampling.num_frames_per_video == 1:
+ out = torch.zeros(len(t), 1, device=t.device)
+ else:
+ if self.cfg.sampling.type == 'uniform':
+ num_diffs_to_use = 1
+ t_diffs = t[:, 1] - t[:, 0] # [batch_size]
+ else:
+ num_diffs_to_use = self.cfg.sampling.num_frames_per_video - 1
+ t_diffs = (t[:, 1:] - t[:, :-1]).view(-1) # [batch_size * (num_frames - 1)]
+ # Note: float => round => long is necessary when it's originally long
+ const_embs = self.const_embed(t_diffs.float().round().long()) # [batch_size * num_diffs_to_use, d]
+ fourier_embs = self.time_encoder(t_diffs.unsqueeze(1)) # [batch_size * num_diffs_to_use, num_fourier_feats]
+ out = torch.cat([const_embs, fourier_embs], dim=1) # [batch_size * num_diffs_to_use, d + num_fourier_feats]
+ out = out.view(batch_size, num_diffs_to_use, -1).view(batch_size, -1) # [batch_size, num_diffs_to_use * (d + num_fourier_feats)]
+
+ return out
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class FixedTimeEncoder(nn.Module):
+ def __init__(self,
+ max_num_frames: int, # Maximum T size
+ skip_small_t_freqs: int=0, # How many high frequencies we should skip
+ ):
+ super().__init__()
+
+ assert max_num_frames >= 1, f"Wrong max_num_frames: {max_num_frames}"
+ fourier_coefs = construct_log_spaced_freqs(max_num_frames, skip_small_t_freqs=skip_small_t_freqs)
+ self.register_buffer('fourier_coefs', fourier_coefs) # [1, num_fourier_feats]
+
+ def get_dim(self) -> int:
+ return self.fourier_coefs.shape[1] * 2
+
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
+ assert t.ndim == 2, f"Wrong shape: {t.shape}"
+
+ t = t.view(-1).float() # [batch_size * num_frames]
+ fourier_raw_embs = self.fourier_coefs * t.unsqueeze(1) # [bf, num_fourier_feats]
+
+ fourier_embs = torch.cat([
+ fourier_raw_embs.sin(),
+ fourier_raw_embs.cos(),
+ ], dim=1) # [bf, num_fourier_feats * 2]
+
+ return fourier_embs
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class EqLRConv1d(nn.Module):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ kernel_size: int,
+ padding: int=0,
+ stride: int=1,
+ activation: str='linear',
+ lr_multiplier: float=1.0,
+ bias=True,
+ bias_init=0.0,
+ ):
+ super().__init__()
+
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features, kernel_size]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], float(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features * kernel_size)
+ self.bias_gain = lr_multiplier
+ self.padding = padding
+ self.stride = stride
+
+ assert self.activation in ['lrelu', 'linear']
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ assert x.ndim == 3, f"Wrong shape: {x.shape}"
+
+ w = self.weight.to(x.dtype) * self.weight_gain # [out_features, in_features, kernel_size]
+ b = self.bias # [out_features]
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ y = F.conv1d(input=x, weight=w, bias=b, stride=self.stride, padding=self.padding) # [batch_size, out_features, out_len]
+ if self.activation == 'linear':
+ pass
+ elif self.activation == 'lrelu':
+ y = F.leaky_relu(y, negative_slope=0.2) # [batch_size, out_features, out_len]
+ else:
+ raise NotImplementedError
+ return y
+
+#----------------------------------------------------------------------------
+
+def sample_frames(cfg: Dict, total_video_len: int, **kwargs) -> np.ndarray:
+ if cfg['type'] == 'random':
+ return random_frame_sampling(cfg, total_video_len, **kwargs)
+ elif cfg['type'] == 'uniform':
+ return uniform_frame_sampling(cfg, total_video_len, **kwargs)
+ else:
+ raise NotImplementedError
+
+#----------------------------------------------------------------------------
+
+def random_frame_sampling(cfg: Dict, total_video_len: int, use_fractional_t: bool=False) -> np.ndarray:
+ min_time_diff = cfg["num_frames_per_video"] - 1
+ max_time_diff = min(total_video_len - 1, cfg.get('max_dist', float('inf')))
+
+ if type(cfg.get('total_dists')) in (list, tuple):
+ time_diff_range = [d for d in cfg['total_dists'] if min_time_diff <= d <= max_time_diff]
+ else:
+ time_diff_range = range(min_time_diff, max_time_diff)
+
+ time_diff: int = random.choice(time_diff_range)
+ if use_fractional_t:
+ offset = random.random() * (total_video_len - time_diff - 1)
+ else:
+ offset = random.randint(0, total_video_len - time_diff - 1)
+ frames_idx = [offset]
+
+ if cfg["num_frames_per_video"] > 1:
+ frames_idx.append(offset + time_diff)
+
+ if cfg["num_frames_per_video"] > 2:
+ frames_idx.extend([(offset + t) for t in random.sample(range(1, time_diff), k=cfg["num_frames_per_video"] - 2)])
+
+ frames_idx = sorted(frames_idx)
+
+ return np.array(frames_idx)
+
+#----------------------------------------------------------------------------
+
+def uniform_frame_sampling(cfg: Dict, total_video_len: int, use_fractional_t: bool=False) -> np.ndarray:
+ # Step 1: Select the distance between frames
+ if type(cfg.get('dists_between_frames')) in (list, tuple):
+ valid_dists = [d for d in cfg['dists_between_frames'] if d <= ['max_dist_between_frames']]
+ valid_dists = [d for d in valid_dists if (d * cfg['num_frames_per_video'] - d + 1) <= total_video_len]
+ d = random.choice(valid_dists)
+ else:
+ max_dist = min(cfg.get('max_dist', float('inf')), total_video_len // cfg['num_frames_per_video'])
+ d = random.randint(1, max_dist)
+
+ d_total = d * cfg['num_frames_per_video'] - d + 1
+
+ # Step 2: Sample.
+ if use_fractional_t:
+ offset = random.random() * (total_video_len - d_total)
+ else:
+ offset = random.randint(0, total_video_len - d_total)
+
+ frames_idx = offset + np.arange(cfg['num_frames_per_video']) * d
+
+ return frames_idx
+
+#----------------------------------------------------------------------------
+
+def construct_log_spaced_freqs(max_num_frames: int, skip_small_t_freqs: int=0) -> Tuple[int, torch.Tensor]:
+ time_resolution = 2 ** np.ceil(np.log2(max_num_frames))
+ num_fourier_feats = np.ceil(np.log2(time_resolution)).astype(int)
+ powers = torch.tensor([2]).repeat(num_fourier_feats).pow(torch.arange(num_fourier_feats)) # [num_fourier_feats]
+ powers = powers[:len(powers) - skip_small_t_freqs] # [num_fourier_feats]
+ fourier_coefs = powers.unsqueeze(0).float() * np.pi # [1, num_fourier_feats]
+
+ return fourier_coefs / time_resolution
+
+#----------------------------------------------------------------------------
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..64c4fe1c68bc1b32412ffb87d3117b11cd8f67e1
--- /dev/null
+++ b/train.py
@@ -0,0 +1,281 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+A minimal training script for Latte using PyTorch DDP.
+"""
+
+
+import torch
+# Maybe use fp16 percision training need to set to False
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+import io
+import os
+import math
+import argparse
+
+import torch.distributed as dist
+from glob import glob
+from time import time
+from copy import deepcopy
+from einops import rearrange
+from models import get_models
+from datasets import get_dataset
+from models.clip import TextEmbedder
+from diffusion import create_diffusion
+from omegaconf import OmegaConf
+from torch.utils.data import DataLoader
+from diffusers.models import AutoencoderKL
+from diffusers.optimization import get_scheduler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from utils import (clip_grad_norm_, create_logger, update_ema,
+ requires_grad, cleanup, create_tensorboard,
+ write_tensorboard, setup_distributed,
+ get_experiment_dir, text_preprocessing)
+import numpy as np
+from transformers import T5EncoderModel, T5Tokenizer
+
+#################################################################################
+# Training Loop #
+#################################################################################
+
+def main(args):
+
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
+
+ # Setup DDP:
+ setup_distributed()
+ # dist.init_process_group("nccl")
+ # assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
+ # rank = dist.get_rank()
+ # device = rank % torch.cuda.device_count()
+ # local_rank = rank
+
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ device = torch.device("cuda", local_rank)
+
+ seed = args.global_seed + rank
+ torch.manual_seed(seed)
+ torch.cuda.set_device(device)
+ print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.")
+
+ # Setup an experiment folder:
+ if rank == 0:
+ os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
+ experiment_index = len(glob(f"{args.results_dir}/*"))
+ model_string_name = args.model.replace("/", "-") # e.g., Latte-XL/2 --> Latte-XL-2 (for naming folders)
+ num_frame_string = 'F' + str(args.num_frames) + 'S' + str(args.frame_interval)
+ experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}" # Create an experiment folder
+ experiment_dir = get_experiment_dir(experiment_dir, args)
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ logger = create_logger(experiment_dir)
+ tb_writer = create_tensorboard(experiment_dir)
+ OmegaConf.save(args, os.path.join(experiment_dir, 'config.yaml'))
+ logger.info(f"Experiment directory created at {experiment_dir}")
+ else:
+ logger = create_logger(None)
+ tb_writer = None
+
+ # Create model:
+ assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
+ sample_size = args.image_size // 8
+ args.latent_size = sample_size
+ model = get_models(args)
+ # Note that parameter initialization is done within the Latte constructor
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
+ requires_grad(ema, False)
+ diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
+ # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
+
+ # # use pretrained model?
+ if args.pretrained:
+ checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
+ if "ema" in checkpoint: # supports checkpoints from train.py
+ logger.info('Using ema ckpt!')
+ checkpoint = checkpoint["ema"]
+
+ model_dict = model.state_dict()
+ # 1. filter out unnecessary keys
+ pretrained_dict = {}
+ for k, v in checkpoint.items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ logger.info('Ignoring: {}'.format(k))
+ logger.info('Successfully Load {}% original pretrained model weights '.format(len(pretrained_dict) / len(checkpoint.items()) * 100))
+ # 2. overwrite entries in the existing state dict
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ logger.info('Successfully load model at {}!'.format(args.pretrained))
+
+ if args.use_compile:
+ model = torch.compile(model)
+
+ # set distributed training
+ model = DDP(model.to(device), device_ids=[local_rank])
+
+ logger.info(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
+ opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+
+ # Setup data:
+ dataset = get_dataset(args)
+
+ sampler = DistributedSampler(
+ dataset,
+ num_replicas=dist.get_world_size(),
+ rank=rank,
+ shuffle=True,
+ seed=args.global_seed
+ )
+ loader = DataLoader(
+ dataset,
+ batch_size=int(args.local_batch_size),
+ shuffle=False,
+ sampler=sampler,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=True
+ )
+ logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")
+
+ # Scheduler
+ lr_scheduler = get_scheduler(
+ name="constant",
+ optimizer=opt,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare models for training:
+ update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
+ model.train() # important! This enables embedding dropout for classifier-free guidance
+ ema.eval() # EMA model should always be in eval mode
+
+ # Variables for monitoring/logging purposes:
+ train_steps = 0
+ log_steps = 0
+ running_loss = 0
+ first_epoch = 0
+ start_time = time()
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(loader))
+ # Afterwards we recalculate our number of training epochs
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ # TODO, need to checkout
+ # Get the most recent checkpoint
+ dirs = os.listdir(os.path.join(experiment_dir, 'checkpoints'))
+ dirs = [d for d in dirs if d.endswith("pt")]
+ dirs = sorted(dirs, key=lambda x: int(x.split(".")[0]))
+ path = dirs[-1]
+ logger.info(f"Resuming from checkpoint {path}")
+ model.load_state(os.path.join(dirs, path))
+ train_steps = int(path.split(".")[0])
+
+ first_epoch = train_steps // num_update_steps_per_epoch
+ resume_step = train_steps % num_update_steps_per_epoch
+
+ if args.pretrained:
+ train_steps = int(args.pretrained.split("/")[-1].split('.')[0])
+
+ for epoch in range(first_epoch, num_train_epochs):
+ sampler.set_epoch(epoch)
+ for step, video_data in enumerate(loader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ continue
+
+ x = video_data['video'].to(device, non_blocking=True)
+ video_name = video_data['video_name']
+ with torch.no_grad():
+ # Map input images to latent space + normalize latents:
+ b, _, _, _, _ = x.shape
+ x = rearrange(x, 'b f c h w -> (b f) c h w').contiguous()
+ x = vae.encode(x).latent_dist.sample().mul_(0.18215)
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=b).contiguous()
+
+ if args.extras == 78: # text-to-video
+ raise 'T2V training are Not supported at this moment!'
+ elif args.extras == 2:
+ model_kwargs = dict(y=video_name)
+ else:
+ model_kwargs = dict(y=None)
+
+ t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
+ loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
+ loss = loss_dict["loss"].mean()
+ loss.backward()
+
+ if train_steps < args.start_clip_iter: # if train_steps >= start_clip_iter, will clip gradient
+ gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=False)
+ else:
+ gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=True)
+
+ opt.step()
+ lr_scheduler.step()
+ opt.zero_grad()
+ update_ema(ema, model.module)
+
+ # Log loss values:
+ running_loss += loss.item()
+ log_steps += 1
+ train_steps += 1
+ if train_steps % args.log_every == 0:
+ # Measure training speed:
+ torch.cuda.synchronize()
+ end_time = time()
+ steps_per_sec = log_steps / (end_time - start_time)
+ # Reduce loss history over all processes:
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+ avg_loss = avg_loss.item() / dist.get_world_size()
+ # logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
+ logger.info(f"(step={train_steps:07d}/epoch={epoch:04d}) Train Loss: {avg_loss:.4f}, Gradient Norm: {gradient_norm:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
+ write_tensorboard(tb_writer, 'Train Loss', avg_loss, train_steps)
+ write_tensorboard(tb_writer, 'Gradient Norm', gradient_norm, train_steps)
+ # Reset monitoring variables:
+ running_loss = 0
+ log_steps = 0
+ start_time = time()
+
+ # Save Latte checkpoint:
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
+ if rank == 0:
+ checkpoint = {
+ "model": model.module.state_dict(),
+ "ema": ema.state_dict(),
+ # "opt": opt.state_dict(),
+ # "args": args
+ }
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
+ torch.save(checkpoint, checkpoint_path)
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
+ dist.barrier()
+
+ model.eval() # important! This disables randomized embedding dropout
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
+
+ logger.info("Done!")
+ cleanup()
+
+
+if __name__ == "__main__":
+ # Default args here will train Latte with the hyperparameters we used in our paper (except training iters).
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/train.yaml")
+ args = parser.parse_args()
+ main(OmegaConf.load(args.config))
diff --git a/train_scripts/ffs_train.sh b/train_scripts/ffs_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9648d6b6625f7e19578fbc808f3139d5f13408b5
--- /dev/null
+++ b/train_scripts/ffs_train.sh
@@ -0,0 +1,3 @@
+export CUDA_VISIBLE_DEVICES=5
+# torchrun --nnodes=1 --nproc_per_node=2 --master_port=29509 train.py --config ./configs/ffs/ffs_train.yaml
+python train.py --config ./configs/ffs/ffs_train.yaml
\ No newline at end of file
diff --git a/train_scripts/sky_train.sh b/train_scripts/sky_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..54de0bb463ba79e6edadea08c4770a6e2682f909
--- /dev/null
+++ b/train_scripts/sky_train.sh
@@ -0,0 +1,2 @@
+export CUDA_VISIBLE_DEVICES=4,5
+torchrun --nnodes=1 --nproc_per_node=2 --master_port=29509 train.py --config ./configs/sky/sky_train.yaml
\ No newline at end of file
diff --git a/train_scripts/taichi_train.sh b/train_scripts/taichi_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bce53214b4d49353ef70c854b1e6de44d90542df
--- /dev/null
+++ b/train_scripts/taichi_train.sh
@@ -0,0 +1,2 @@
+export CUDA_VISIBLE_DEVICES=4,5
+torchrun --nnodes=1 --nproc_per_node=2 --master_port=29509 train.py --config ./configs/taichi/taichi_train.yaml
\ No newline at end of file
diff --git a/train_scripts/ucf101_train.sh b/train_scripts/ucf101_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..26822a67d966ee9b4f65749d363351e360ed2320
--- /dev/null
+++ b/train_scripts/ucf101_train.sh
@@ -0,0 +1,2 @@
+export CUDA_VISIBLE_DEVICES=4,5
+torchrun --nnodes=1 --nproc_per_node=2 --master_port=29509 train.py --config ./configs/ucf101/ucf101_train.yaml
\ No newline at end of file
diff --git a/train_with_img.py b/train_with_img.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8561b2e7581eef2962ac570b725209cbbfc48eb
--- /dev/null
+++ b/train_with_img.py
@@ -0,0 +1,302 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+A minimal training script for Latte using PyTorch DDP.
+"""
+
+
+import torch
+# Maybe use fp16 percision training need to set to False
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+import os
+import math
+import argparse
+
+import torch.distributed as dist
+from glob import glob
+from time import time
+from copy import deepcopy
+from einops import rearrange
+from models import get_models
+from datasets import get_dataset
+from models.clip import TextEmbedder
+from diffusion import create_diffusion
+from omegaconf import OmegaConf
+from torch.utils.data import DataLoader
+from diffusers.models import AutoencoderKL
+from diffusers.optimization import get_scheduler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from utils import (clip_grad_norm_, create_logger, update_ema,
+ requires_grad, cleanup, create_tensorboard,
+ write_tensorboard, setup_distributed, get_experiment_dir)
+
+
+#################################################################################
+# Training Loop #
+#################################################################################
+
+def main(args):
+
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
+
+ # Setup DDP:
+ setup_distributed()
+
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ device = torch.device("cuda", local_rank)
+
+ seed = args.global_seed + rank
+ torch.manual_seed(seed)
+ torch.cuda.set_device(device)
+ print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.")
+
+ # Setup an experiment folder:
+ if rank == 0:
+ os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
+ experiment_index = len(glob(f"{args.results_dir}/*"))
+ model_string_name = args.model.replace("/", "-") # e.g., Latte-XL/2 --> Latte-XL-2 (for naming folders)
+ num_frame_string = 'F' + str(args.num_frames) + 'S' + str(args.frame_interval)
+ experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}" # Create an experiment folder
+ experiment_dir = get_experiment_dir(experiment_dir, args)
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ logger = create_logger(experiment_dir)
+ tb_writer = create_tensorboard(experiment_dir)
+ OmegaConf.save(args, os.path.join(experiment_dir, 'config.yaml'))
+ logger.info(f"Experiment directory created at {experiment_dir}")
+ else:
+ logger = create_logger(None)
+ tb_writer = None
+
+ # Create model:
+ assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
+ sample_size = args.image_size // 8
+ args.latent_size = sample_size
+ model = get_models(args)
+ # Note that parameter initialization is done within the Latte constructor
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
+ requires_grad(ema, False)
+ diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
+ # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-mse").to(device)
+
+ # # use pretrained model?
+ if args.pretrained:
+ checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
+ if "ema" in checkpoint: # supports checkpoints from train.py
+ logger.info('Using ema ckpt!')
+ checkpoint = checkpoint["ema"]
+
+ model_dict = model.state_dict()
+ # 1. filter out unnecessary keys
+ pretrained_dict = {}
+ for k, v in checkpoint.items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ logger.info('Ignoring: {}'.format(k))
+ logger.info('Successfully Load {}% original pretrained model weights '.format(len(pretrained_dict) / len(checkpoint.items()) * 100))
+ # 2. overwrite entries in the existing state dict
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ logger.info('Successfully load model at {}!'.format(args.pretrained))
+
+ if args.use_compile:
+ model = torch.compile(model)
+
+ if args.enable_xformers_memory_efficient_attention:
+ logger.info("Using Xformers!")
+ model.enable_xformers_memory_efficient_attention()
+
+ if args.gradient_checkpointing:
+ logger.info("Using gradient checkpointing!")
+ model.enable_gradient_checkpointing()
+
+ if args.fixed_spatial:
+ trainable_modules = (
+ "attn_temp",
+ )
+ model.requires_grad_(False)
+ for name, module in model.named_modules():
+ if name.endswith(tuple(trainable_modules)):
+ for params in module.parameters():
+ logger.info("WARNING: Only train {} parametes!".format(name))
+ params.requires_grad = True
+ logger.info("WARNING: Only train {} parametes!".format(trainable_modules))
+
+ # set distributed training
+ model = DDP(model.to(device), device_ids=[local_rank])
+
+ logger.info(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
+ opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+
+ # Setup data:
+ dataset = get_dataset(args)
+
+ sampler = DistributedSampler(
+ dataset,
+ num_replicas=dist.get_world_size(),
+ rank=rank,
+ shuffle=True,
+ seed=args.global_seed
+ )
+ loader = DataLoader(
+ dataset,
+ batch_size=int(args.local_batch_size),
+ shuffle=False,
+ sampler=sampler,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=True
+ )
+ logger.info(f"Dataset contains {len(dataset):,} videos ({args.webvideo_data_path})")
+
+ # Scheduler
+ lr_scheduler = get_scheduler(
+ name="constant",
+ optimizer=opt,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare models for training:
+ update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
+ model.train() # important! This enables embedding dropout for classifier-free guidance
+ ema.eval() # EMA model should always be in eval mode
+
+ # Variables for monitoring/logging purposes:
+ train_steps = 0
+ log_steps = 0
+ running_loss = 0
+ first_epoch = 0
+ start_time = time()
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(loader))
+ # Afterwards we recalculate our number of training epochs
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ # TODO, need to checkout
+ # Get the most recent checkpoint
+ dirs = os.listdir(os.path.join(experiment_dir, 'checkpoints'))
+ dirs = [d for d in dirs if d.endswith("pt")]
+ dirs = sorted(dirs, key=lambda x: int(x.split(".")[0]))
+ path = dirs[-1]
+ logger.info(f"Resuming from checkpoint {path}")
+ model.load_state(os.path.join(dirs, path))
+ train_steps = int(path.split(".")[0])
+
+ first_epoch = train_steps // num_update_steps_per_epoch
+ resume_step = train_steps % num_update_steps_per_epoch
+
+ for epoch in range(first_epoch, num_train_epochs):
+ sampler.set_epoch(epoch)
+ for step, video_data in enumerate(loader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ continue
+
+ x = video_data['video'].to(device, non_blocking=True)
+ video_name = video_data['video_name']
+ if args.dataset == "ucf101_img":
+ image_name = video_data['image_name']
+ image_names = []
+ for caption in image_name:
+ single_caption = [int(item) for item in caption.split('=====')]
+ image_names.append(torch.as_tensor(single_caption))
+ # x = x.to(device)
+ # y = y.to(device) # y is text prompt; no need put in gpu
+ with torch.no_grad():
+ # Map input images to latent space + normalize latents:
+ b, _, _, _, _ = x.shape
+ x = rearrange(x, 'b f c h w -> (b f) c h w').contiguous()
+ x = vae.encode(x).latent_dist.sample().mul_(0.18215)
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=b).contiguous()
+
+ if args.extras == 78: # text-to-video
+ raise 'T2V training are Not supported at this moment!'
+ elif args.extras == 2:
+ if args.dataset == "ucf101_img":
+ model_kwargs = dict(y=video_name, y_image=image_names, use_image_num=args.use_image_num) # tav unet
+ else:
+ model_kwargs = dict(y=video_name) # tav unet
+ else:
+ model_kwargs = dict(y=None, use_image_num=args.use_image_num)
+
+ t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
+ loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
+ loss = loss_dict["loss"].mean()
+ loss.backward()
+
+ if train_steps < args.start_clip_iter: # if train_steps >= start_clip_iter, will clip gradient
+ gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=False)
+ else:
+ gradient_norm = clip_grad_norm_(model.module.parameters(), args.clip_max_norm, clip_grad=True)
+
+ opt.step()
+ lr_scheduler.step()
+ opt.zero_grad()
+ update_ema(ema, model.module)
+
+ # Log loss values:
+ running_loss += loss.item()
+ log_steps += 1
+ train_steps += 1
+ if train_steps % args.log_every == 0:
+ # Measure training speed:
+ torch.cuda.synchronize()
+ end_time = time()
+ steps_per_sec = log_steps / (end_time - start_time)
+ # Reduce loss history over all processes:
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+ avg_loss = avg_loss.item() / dist.get_world_size()
+ # logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
+ logger.info(f"(step={train_steps:07d}/epoch={epoch:04d}) Train Loss: {avg_loss:.4f}, Gradient Norm: {gradient_norm:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
+ write_tensorboard(tb_writer, 'Train Loss', avg_loss, train_steps)
+ write_tensorboard(tb_writer, 'Gradient Norm', gradient_norm, train_steps)
+ # Reset monitoring variables:
+ running_loss = 0
+ log_steps = 0
+ start_time = time()
+
+ # Save Latte checkpoint:
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
+ if rank == 0:
+ checkpoint = {
+ # "model": model.module.state_dict(),
+ "ema": ema.state_dict(),
+ # "opt": opt.state_dict(),
+ # "args": args
+ }
+
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
+ torch.save(checkpoint, checkpoint_path)
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
+ dist.barrier()
+
+ model.eval() # important! This disables randomized embedding dropout
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
+
+ logger.info("Done!")
+ cleanup()
+
+
+if __name__ == "__main__":
+ # Default args here will train Latte-XL/2 with the hyperparameters we used in our paper (except training iters).
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/sky/sky_train.yaml")
+ args = parser.parse_args()
+ main(OmegaConf.load(args.config))
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c3356969c1db87e9de196ee7481c4bac9cbe5e
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,457 @@
+import os
+import math
+import torch
+import logging
+import random
+import subprocess
+import numpy as np
+import torch.distributed as dist
+
+from torch import inf
+from PIL import Image
+from typing import Union, Iterable
+from collections import OrderedDict
+from torch.utils.tensorboard import SummaryWriter
+
+from diffusers.utils import is_bs4_available, is_ftfy_available
+
+import html
+import re
+import urllib.parse as ul
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
+
+
+#################################################################################
+# Training Clip Gradients #
+#################################################################################
+
+def get_grad_norm(
+ parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:
+ r"""
+ Copy from torch.nn.utils.clip_grad_norm_
+
+ Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were
+ concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+ error_if_nonfinite (bool): if True, an error is thrown if the total
+ norm of the gradients from :attr:`parameters` is ``nan``,
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return torch.tensor(0.)
+ device = grads[0].device
+ if norm_type == inf:
+ norms = [g.detach().abs().max().to(device) for g in grads]
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ return total_norm
+
+def clip_grad_norm_(
+ parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
+ error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
+ r"""
+ Copy from torch.nn.utils.clip_grad_norm_
+
+ Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were
+ concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+ error_if_nonfinite (bool): if True, an error is thrown if the total
+ norm of the gradients from :attr:`parameters` is ``nan``,
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return torch.tensor(0.)
+ device = grads[0].device
+ if norm_type == inf:
+ norms = [g.detach().abs().max().to(device) for g in grads]
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+
+ if clip_grad:
+ if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+ raise RuntimeError(
+ f'The total norm of order {norm_type} for gradients from '
+ '`parameters` is non-finite, so it cannot be clipped. To disable '
+ 'this error and scale the gradients by the non-finite norm anyway, '
+ 'set `error_if_nonfinite=False`')
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+ # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+ # when the gradients do not reside in CPU memory.
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+ for g in grads:
+ g.detach().mul_(clip_coef_clamped.to(g.device))
+ # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ return total_norm
+
+def get_experiment_dir(root_dir, args):
+ # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained:
+ # root_dir += '-WOPRE'
+ if args.use_compile:
+ root_dir += '-Compile' # speedup by torch compile
+ if args.fixed_spatial:
+ root_dir += '-FixedSpa'
+ if args.enable_xformers_memory_efficient_attention:
+ root_dir += '-Xfor'
+ if args.gradient_checkpointing:
+ root_dir += '-Gc'
+ if args.mixed_precision:
+ root_dir += '-Amp'
+ if args.image_size == 512:
+ root_dir += '-512'
+ return root_dir
+
+#################################################################################
+# Training Logger #
+#################################################################################
+
+def create_logger(logging_dir):
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ if dist.get_rank() == 0: # real logger
+ logging.basicConfig(
+ level=logging.INFO,
+ # format='[\033[34m%(asctime)s\033[0m] %(message)s',
+ format='[%(asctime)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
+ )
+ logger = logging.getLogger(__name__)
+
+ else: # dummy logger (does nothing)
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+ return logger
+
+
+def create_tensorboard(tensorboard_dir):
+ """
+ Create a tensorboard that saves losses.
+ """
+ if dist.get_rank() == 0: # real tensorboard
+ # tensorboard
+ writer = SummaryWriter(tensorboard_dir)
+
+ return writer
+
+def write_tensorboard(writer, *args):
+ '''
+ write the loss information to a tensorboard file.
+ Only for pytorch DDP mode.
+ '''
+ if dist.get_rank() == 0: # real tensorboard
+ writer.add_scalar(args[0], args[1], args[2])
+
+#################################################################################
+# EMA Update/ DDP Training Utils #
+#################################################################################
+
+@torch.no_grad()
+def update_ema(ema_model, model, decay=0.9999):
+ """
+ Step the EMA model towards the current model.
+ """
+ ema_params = OrderedDict(ema_model.named_parameters())
+ model_params = OrderedDict(model.named_parameters())
+
+ for name, param in model_params.items():
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
+
+def requires_grad(model, flag=True):
+ """
+ Set requires_grad flag for all parameters in a model.
+ """
+ for p in model.parameters():
+ p.requires_grad = flag
+
+def cleanup():
+ """
+ End DDP training.
+ """
+ dist.destroy_process_group()
+
+
+def setup_distributed(backend="nccl", port=None):
+ """Initialize distributed training environment.
+ support both slurm and torch.distributed.launch
+ see torch.distributed.init_process_group() for more details
+ """
+ num_gpus = torch.cuda.device_count()
+
+ if "SLURM_JOB_ID" in os.environ:
+ rank = int(os.environ["SLURM_PROCID"])
+ world_size = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
+ # specify master port
+ if port is not None:
+ os.environ["MASTER_PORT"] = str(port)
+ elif "MASTER_PORT" not in os.environ:
+ # os.environ["MASTER_PORT"] = "29566"
+ os.environ["MASTER_PORT"] = str(29567 + num_gpus)
+ if "MASTER_ADDR" not in os.environ:
+ os.environ["MASTER_ADDR"] = addr
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["LOCAL_RANK"] = str(rank % num_gpus)
+ os.environ["RANK"] = str(rank)
+ else:
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+ # torch.cuda.set_device(rank % num_gpus)
+
+ dist.init_process_group(
+ backend=backend,
+ world_size=world_size,
+ rank=rank,
+ )
+
+#################################################################################
+# Testing Utils #
+#################################################################################
+
+def save_video_grid(video, nrow=None):
+ b, t, h, w, c = video.shape
+
+ if nrow is None:
+ nrow = math.ceil(math.sqrt(b))
+ ncol = math.ceil(b / nrow)
+ padding = 1
+ video_grid = torch.zeros((t, (padding + h) * nrow + padding,
+ (padding + w) * ncol + padding, c), dtype=torch.uint8)
+
+ for i in range(b):
+ r = i // ncol
+ c = i % ncol
+ start_r = (padding + h) * r
+ start_c = (padding + w) * c
+ video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
+
+ return video_grid
+
+def find_model(model_name):
+ """
+ Finds a pre-trained Latte model, downloading it if necessary. Alternatively, loads a model from a local path.
+ """
+ assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}'
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+
+ if "ema" in checkpoint: # supports checkpoints from train.py
+ print('Using Ema!')
+ checkpoint = checkpoint["ema"]
+ else:
+ print('Using model!')
+ checkpoint = checkpoint['model']
+ return checkpoint
+
+#################################################################################
+# MMCV Utils #
+#################################################################################
+
+
+def collect_env():
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from mmcv.utils import collect_env as collect_base_env
+ from mmcv.utils import get_git_hash
+ """Collect the information of the running environments."""
+
+ env_info = collect_base_env()
+ env_info['MMClassification'] = get_git_hash()[:7]
+
+ for name, val in env_info.items():
+ print(f'{name}: {val}')
+
+ print(torch.cuda.get_arch_list())
+ print(torch.version.cuda)
+
+
+#################################################################################
+# Pixart-alpha Utils #
+#################################################################################
+
+bad_punct_regex = re.compile(
+ r"[" + "#ยฎโขยฉโข&@ยทยบยฝยพยฟยกยง~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+)
+
+def text_preprocessing(text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = clean_caption(text)
+ text = clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+def clean_caption(caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0โ31EF CJK Strokes
+ # 31F0โ31FF Katakana Phonetic Extensions
+ # 3200โ32FF Enclosed CJK Letters and Months
+ # 3300โ33FF CJK Compatibility
+ # 3400โ4DBF CJK Unified Ideographs Extension A
+ # 4DC0โ4DFF Yijing Hexagram Symbols
+ # 4E00โ9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # ะฒัะต ะฒะธะดั ัะธัะต / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # ะบะฐะฒััะบะธ ะบ ะพะดะฝะพะผั ััะฐะฝะดะฐััั
+ caption = re.sub(r"[`ยดยซยปโโยจ]", '"', caption)
+ caption = re.sub(r"[โโ]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xั
ร]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+
+
+
+
+
diff --git a/visuals/architecture.svg b/visuals/architecture.svg
new file mode 100644
index 0000000000000000000000000000000000000000..7ad8be5fdec9d4194ee59e5e46c5537f8db00184
--- /dev/null
+++ b/visuals/architecture.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/visuals/latte.gif b/visuals/latte.gif
new file mode 100644
index 0000000000000000000000000000000000000000..76fed2abbbb6c380a9cf8d09e82c484f7e427a46
--- /dev/null
+++ b/visuals/latte.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdf9e21af9816e9192554b7c4c29abc2a8aaf360bd7fc91b33c3a6baa6147425
+size 13943841
diff --git a/visuals/latteT2V.gif b/visuals/latteT2V.gif
new file mode 100644
index 0000000000000000000000000000000000000000..572e0187e47e4f3a33722c5c5ff347c0a267e2bb
--- /dev/null
+++ b/visuals/latteT2V.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cf71d82889a0989d44be0adde1c16726cbf4cdf9e63a9d7e22a2c394a5dd891
+size 23910650