diff --git a/Text2Human/.gitignore b/Text2Human/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..053631a757f77ff272ac6da4613368c097bbadb4
--- /dev/null
+++ b/Text2Human/.gitignore
@@ -0,0 +1,9 @@
+__pycache__/
+.cache/
+datasets/*
+experiments/*
+tb_logger/*
+results/*
+*.png
+*.txt
+*.pth
\ No newline at end of file
diff --git a/Text2Human/LICENSE b/Text2Human/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..854cfa0bd6b1ed05fd3adbd07938128137f0f76f
--- /dev/null
+++ b/Text2Human/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Yuming Jiang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Text2Human/README.md b/Text2Human/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8e93c456558086ba888d76351ae194973f32dd20
--- /dev/null
+++ b/Text2Human/README.md
@@ -0,0 +1,255 @@
+# Text2Human - Official PyTorch Implementation
+
+
+
+This repository provides the official PyTorch implementation for the following paper:
+
+**Text2Human: Text-Driven Controllable Human Image Generation**
+[Yuming Jiang](https://yumingj.github.io/), [Shuai Yang](https://williamyang1991.github.io/), [Haonan Qiu](http://haonanqiu.com/), [Wayne Wu](https://dblp.org/pid/50/8731.html), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) and [Ziwei Liu](https://liuziwei7.github.io/)
+In ACM Transactions on Graphics (Proceedings of SIGGRAPH), 2022.
+
+From [MMLab@NTU](https://www.mmlab-ntu.com/index.html), affiliated with S-Lab, Nanyang Technological University, and SenseTime Research.
+
+Sample generation results, with the corresponding text descriptions:
+
+1. The lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt.
+2. The man wears a long and floral shirt, and long pants with the pure color pattern.
+3. A lady is wearing a sleeveless pure-color shirt and long jeans.
+4. The man wears a short-sleeve T-shirt with the pure color pattern and a short pants with the pure color pattern.
+
+[**[Project Page]**](https://yumingj.github.io/projects/Text2Human.html) | [**[Paper]**](https://arxiv.org/pdf/2205.15996.pdf) | [**[Dataset]**](https://github.com/yumingj/DeepFashion-MultiModal) | [**[Demo Video]**](https://youtu.be/yKh4VORA_E0)
+
+
+## Updates
+
+- [05/2022] Paper and demo video are released.
+- [05/2022] Code is released.
+- [05/2022] This website is created.
+
+## Installation
+**Clone this repo:**
+```bash
+git clone https://github.com/yumingj/Text2Human.git
+cd Text2Human
+```
+**Dependencies:**
+
+All dependencies for defining the environment are provided in `environment/text2human_env.yaml`.
+We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage the python environment:
+```bash
+conda env create -f ./environment/text2human_env.yaml
+conda activate text2human
+conda install -c huggingface tokenizers=0.9.4
+conda install -c huggingface transformers=4.0.0
+conda install -c conda-forge sentence-transformers=2.0.0
+```
+
+If the environment file does not work, you may need to install the following packages on your own:
+ - Python 3.6
+ - PyTorch 1.7.1
+ - CUDA 10.1
+ - [sentence-transformers](https://huggingface.co./sentence-transformers) 2.0.0
+ - [tokenizers](https://pypi.org/project/tokenizers/) 0.9.4
+ - [transformers](https://huggingface.co./docs/transformers/installation) 4.0.0
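+
+To quickly confirm that the key dependencies are importable and match the versions above, you can run a short check like the following (a sketch; it only inspects package versions):
+
+```python
+# sanity check for the core dependencies (run inside the activated env)
+import torch
+import tokenizers
+import transformers
+import sentence_transformers
+
+print('PyTorch:', torch.__version__)                 # expected 1.7.1
+print('CUDA available:', torch.cuda.is_available())
+print('tokenizers:', tokenizers.__version__)         # expected 0.9.4
+print('transformers:', transformers.__version__)     # expected 4.0.0
+print('sentence-transformers:', sentence_transformers.__version__)  # expected 2.0.0
+```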
+
+## (1) Dataset Preparation
+
+In this work, we contribute a large-scale high-quality dataset with rich multi-modal annotations named [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
+Here we pre-processed the raw annotations of the original dataset for the task of text-driven controllable human image generation. The pre-processing pipeline consists of:
+ - aligning the human body to the center of the image according to the human pose
+ - fusing the clothing color and clothing fabric annotations into a single texture annotation
+ - cleaning the annotations and filtering the images
+ - splitting the whole dataset into a training set and a testing set
+
+You can download our processed dataset from this [Google Drive](https://drive.google.com/file/d/1KIoFfRZNQVn6RV_wTxG2wZmY8f2T_84B/view?usp=sharing). If you want to access the raw annotations, please refer to the [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
+
+After downloading the dataset, unzip the file and put the contents under the `./datasets` folder with the following structure:
+```
+./datasets
+├── train_images
+ ├── xxx.png
+ ...
+ ├── xxx.png
+ └── xxx.png
+├── test_images
+ % the same structure as in train_images
+├── densepose
+ % the same structure as in train_images
+├── segm
+ % the same structure as in train_images
+├── shape_ann
+ ├── test_ann_file.txt
+ ├── train_ann_file.txt
+ └── val_ann_file.txt
+└── texture_ann
+ ├── test
+ ├── lower_fused.txt
+ ├── outer_fused.txt
+ └── upper_fused.txt
+ ├── train
+ % the same files as in test
+ └── val
+ % the same files as in test
+```
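+
+As a quick sanity check of the layout above, a small script such as the following (a sketch; adjust `root` if you unzip the archive elsewhere) verifies that the expected folders and annotation files are in place:
+
+```python
+import os
+
+# expected layout under ./datasets, as shown in the tree above
+root = './datasets'
+expected_dirs = ['train_images', 'test_images', 'densepose', 'segm']
+expected_files = [
+    f'shape_ann/{split}_ann_file.txt' for split in ('train', 'val', 'test')
+] + [
+    f'texture_ann/{split}/{part}_fused.txt'
+    for split in ('train', 'val', 'test')
+    for part in ('upper', 'lower', 'outer')
+]
+
+for d in expected_dirs:
+    assert os.path.isdir(os.path.join(root, d)), f'missing folder: {d}'
+for f in expected_files:
+    assert os.path.isfile(os.path.join(root, f)), f'missing file: {f}'
+print('dataset layout looks complete')
+```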
+
+## (2) Sampling
+
+### Inference Notebook
+
+Coming soon.
+
+
+### Pretrained Models
+
+Pretrained models can be downloaded from this [Google Drive](https://drive.google.com/file/d/1VyI8_AbPwAUaZJPaPba8zxsFIWumlDen/view?usp=sharing). Unzip the file and put the models under the `./pretrained_models` folder with the following structure:
+```
+pretrained_models
+├── index_pred_net.pth
+├── parsing_gen.pth
+├── parsing_token.pth
+├── sampler.pth
+├── vqvae_bottom.pth
+└── vqvae_top.pth
+```
+
+### Generation from Parsing Maps
+You can generate images from given parsing maps and pre-defined texture annotations:
+```bash
+python sample_from_parsing.py -opt ./configs/sample_from_parsing.yml
+```
+The results are saved in the folder `./results/sampling_from_parsing`.
+
+### Generation from Poses
+You can generate images from given human poses and pre-defined clothing shape and texture annotations:
+```bash
+python sample_from_pose.py -opt ./configs/sample_from_pose.yml
+```
+
+**Remarks**: The above two scripts generate images without language interactions. If you want to generate images from text descriptions, you can use the notebook or our user interface.
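+
+Both scripts read their settings from the YAML files under `./configs` (included later in this change). If you want to inspect or tweak a configuration programmatically before launching a run, a plain YAML load is sufficient (a sketch using PyYAML; the scripts themselves parse the file passed via `-opt`):
+
+```python
+import yaml
+
+# load the sampling config and inspect / override a few entries
+with open('./configs/sample_from_parsing.yml') as f:
+    opt = yaml.safe_load(f)
+
+print(opt['model_type'])          # SampleFromParsingModel
+print(opt['pretrained_sampler'])  # ./pretrained_models/sampler.pth
+
+# write a modified copy (hypothetical file name) and pass it via -opt
+opt['batch_size'] = 1
+with open('./configs/sample_from_parsing_bs1.yml', 'w') as f:
+    yaml.safe_dump(opt, f)
+```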
+
+### User Interface
+
+```bash
+python ui_demo.py
+```
+
+
+The descriptions for shapes should follow this format:
+```
+<gender>, <upper clothing>, <lower clothing>, <outer clothing>, <accessories>, ...
+
+Note: The outer clothing type and accessories can be omitted.
+
+Examples:
+man, sleeveless T-shirt, long pants
+woman, short-sleeve T-shirt, short jeans
+```
+
+The descriptions for textures should follow this format:
+```
+<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
+
+Note: Currently, we only support 5 types of textures, i.e., pure color, stripe/spline, plaid/lattice,
+ floral, denim. Your inputs should be restricted to these textures.
+```
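+
+Because only the five textures listed above are supported, a tiny validator like the following (a sketch; the texture names are exactly those in the note) can catch typos before you submit a description:
+
+```python
+# the five supported texture types, as listed in the note above
+SUPPORTED_TEXTURES = {'pure color', 'stripe/spline', 'plaid/lattice', 'floral', 'denim'}
+
+
+def check_texture_description(desc):
+    """Raise ValueError if any comma-separated entry is not a supported texture."""
+    for part in desc.split(','):
+        texture = part.strip()
+        if texture and texture not in SUPPORTED_TEXTURES:
+            raise ValueError(f'unsupported texture: {texture!r}')
+
+
+check_texture_description('pure color, denim')   # ok
+check_texture_description('pure color, floral')  # ok
+```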
+
+## (3) Training Text2Human
+
+### Stage I: Pose to Parsing
+Train the parsing generation network. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1MNyFLGqIQcOMg_HhgwCmKqdwfQSjeg_6/view?usp=sharing).
+```bash
+python train_parsing_gen.py -opt ./configs/parsing_gen.yml
+```
+
+### Stage II: Parsing to Human
+
+**Step 1: Train the top level of the hierarchical VQVAE.**
+We provide our pretrained model [here](https://drive.google.com/file/d/1TwypUg85gPFJtMwBLUjVS66FKR3oaTz8/view?usp=sharing). This model is trained by:
+```bash
+python train_vqvae.py -opt ./configs/vqvae_top.yml
+```
+
+**Step 2: Train the bottom level of the hierarchical VQVAE.**
+We provide our pretrained model [here](https://drive.google.com/file/d/15hzbY-RG-ILgzUqqGC0qMzlS4OayPdRH/view?usp=sharing). This model is trained by:
+```bash
+python train_vqvae.py -opt ./configs/vqvae_bottom.yml
+```
+
+**Steps 3 & 4: Train the sampler with mixture-of-experts.** To train the sampler, we first need to train a model to tokenize the parsing maps. You can access our pretrained parsing tokenization model [here](https://drive.google.com/file/d/1GLHoOeCP6sMao1-R63ahJMJF7-J00uir/view?usp=sharing).
+```bash
+python train_parsing_token.py -opt ./configs/parsing_token.yml
+```
+
+With the parsing tokenization model, the sampler is trained by:
+```bash
+python train_sampler.py -opt ./configs/sampler.yml
+```
+Our pretrained sampler is provided [here](https://drive.google.com/file/d/1OQO_kG2fK7eKiG1VJH1OL782X71UQAmS/view?usp=sharing).
+
+**Step 5: Train the index prediction network.**
+We provide our pretrained index prediction network [here](https://drive.google.com/file/d/1rqhkQD-JGd7YBeIfDvMV-vjfbNHpIhYm/view?usp=sharing). It is trained by:
+```bash
+python train_index_prediction.py -opt ./configs/index_pred_net.yml
+```
+
+
+**Remarks**: In the config files, the paths of the required pretrained models point to our released checkpoints. If you want to train the models from scratch, please replace these paths with your own ones. We set the number of training epochs to a large value, so you can pick the best epoch for each model. For your reference, our pretrained parsing generation network is trained for 50 epochs, the top-level VQVAE for 135 epochs, the bottom-level VQVAE for 70 epochs, the parsing tokenization network for 20 epochs, the sampler for 95 epochs, and the index prediction network for 70 epochs.
+
+## (4) Results
+
+Please visit our [Project Page](https://yumingj.github.io/projects/Text2Human.html#results) to view more results.
+You can select the attributes to customize the desired human images.
+
+## DeepFashion-MultiModal Dataset
+
+
+
+In this work, we also propose **DeepFashion-MultiModal**, a large-scale high-quality human dataset with rich multi-modal annotations. It has the following properties:
+1. It contains 44,096 high-resolution human images, including 12,701 full body human images.
+2. For each full-body image, we **manually annotate** the human parsing labels of 24 classes.
+3. For each full-body image, we **manually annotate** the keypoints.
+4. We extract DensePose for each human image.
+5. Each image is **manually annotated** with attributes for both clothing shapes and textures.
+6. We provide a textual description for each image.
+
+
+
+Please refer to [this repo](https://github.com/yumingj/DeepFashion-MultiModal) for more details about our proposed dataset.
+
+## TODO List
+
+- [ ] Release 1024x512 version of Text2Human.
+- [ ] Train Text2Human on the [SHHQ dataset](https://stylegan-human.github.io/).
+
+## Citation
+
+If you find this work useful for your research, please consider citing our paper:
+
+```bibtex
+@article{jiang2022text2human,
+ title={Text2Human: Text-Driven Controllable Human Image Generation},
+ author={Jiang, Yuming and Yang, Shuai and Qiu, Haonan and Wu, Wayne and Loy, Chen Change and Liu, Ziwei},
+ journal={ACM Transactions on Graphics (TOG)},
+ volume={41},
+ number={4},
+ articleno={162},
+ pages={1--11},
+ year={2022},
+ publisher={ACM New York, NY, USA},
+ doi={10.1145/3528223.3530104},
+}
+```
+
+## Acknowledgments
+
+Part of the code is borrowed from [unleashing-transformers](https://github.com/samb-t/unleashing-transformers), [taming-transformers](https://github.com/CompVis/taming-transformers) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
diff --git a/Text2Human/configs/index_pred_net.yml b/Text2Human/configs/index_pred_net.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9269d38d3b260f39f161d765fd641048be9a836b
--- /dev/null
+++ b/Text2Human/configs/index_pred_net.yml
@@ -0,0 +1,84 @@
+name: index_prediction_network
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 4
+num_workers: 4
+train_img_dir: ./datasets/train_images
+test_img_dir: ./datasets/test_images
+segm_dir: ./datasets/segm
+pose_dir: ./datasets/densepose
+train_ann_file: ./datasets/texture_ann/train
+val_ann_file: ./datasets/texture_ann/val
+test_ann_file: ./datasets/texture_ann/test
+downsample_factor: 2
+
+model_type: VQGANTextureAwareSpatialHierarchyInferenceModel
+# network configs
+embed_dim: 256
+n_embed: 1024
+codebook_spatial_size: 2
+
+# bottom level vqvae
+bot_n_embed: 512
+bot_double_z: false
+bot_z_channels: 256
+bot_resolution: 512
+bot_in_channels: 3
+bot_out_ch: 3
+bot_ch: 128
+bot_ch_mult: [1, 1, 2, 4]
+bot_num_res_blocks: 2
+bot_attn_resolutions: [64]
+bot_dropout: 0.0
+bot_vae_path: ./pretrained_models/vqvae_bottom.pth
+
+# top level vqgan
+top_double_z: false
+top_z_channels: 256
+top_resolution: 512
+top_in_channels: 3
+top_out_ch: 3
+top_ch: 128
+top_ch_mult: [1, 1, 2, 2, 4]
+top_num_res_blocks: 2
+top_attn_resolutions: [32]
+top_dropout: 0.0
+top_vae_path: ./pretrained_models/vqvae_top.pth
+
+# unet configs
+encoder_in_channels: 256
+fc_in_channels: 64
+fc_in_index: 4
+fc_channels: 64
+fc_num_convs: 1
+fc_concat_input: False
+fc_dropout_ratio: 0.1
+fc_num_classes: 512
+fc_align_corners: False
+
+disc_layers: 3
+disc_weight_max: 1
+disc_start_step: 30001
+n_channels: 3
+ndf: 64
+nf: 128
+perceptual_weight: 1.0
+
+num_segm_classes: 24
+
+# training configs
+val_freq: 5
+print_freq: 100
+weight_decay: 0
+manual_seed: 2021
+num_epochs: 100
+lr: !!float 1.0e-04
+lr_decay: step
+gamma: 1.0
+step: 50
+optimizer: Adam
+loss_function: cross_entropy
+
diff --git a/Text2Human/configs/parsing_gen.yml b/Text2Human/configs/parsing_gen.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fe8770ce03039704dac797ae2dfab721393af234
--- /dev/null
+++ b/Text2Human/configs/parsing_gen.yml
@@ -0,0 +1,40 @@
+name: parsing_generation
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 8
+num_workers: 4
+segm_dir: ./datasets/segm
+pose_dir: ./datasets/densepose
+train_ann_file: ./datasets/shape_ann/train_ann_file.txt
+val_ann_file: ./datasets/shape_ann/val_ann_file.txt
+test_ann_file: ./datasets/shape_ann/test_ann_file.txt
+downsample_factor: 2
+
+model_type: ParsingGenModel
+# network configs
+embedder_dim: 8
+embedder_out_dim: 128
+attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
+encoder_in_channels: 1
+fc_in_channels: 64
+fc_in_index: 4
+fc_channels: 64
+fc_num_convs: 1
+fc_concat_input: False
+fc_dropout_ratio: 0.1
+fc_num_classes: 24
+fc_align_corners: False
+
+# training configs
+val_freq: 5
+print_freq: 100
+weight_decay: 0
+manual_seed: 2021
+num_epochs: 100
+lr: !!float 1e-4
+lr_decay: step
+gamma: 0.1
+step: 50
diff --git a/Text2Human/configs/parsing_token.yml b/Text2Human/configs/parsing_token.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0f49fdee88b2f565024dca17429836593218ff6d
--- /dev/null
+++ b/Text2Human/configs/parsing_token.yml
@@ -0,0 +1,47 @@
+name: parsing_tokenization
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 4
+num_workers: 4
+train_img_dir: ./datasets/train_images
+test_img_dir: ./datasets/test_images
+segm_dir: ./datasets/segm
+pose_dir: ./datasets/densepose
+train_ann_file: ./datasets/texture_ann/train
+val_ann_file: ./datasets/texture_ann/val
+test_ann_file: ./datasets/texture_ann/test
+downsample_factor: 2
+
+model_type: VQSegmentationModel
+# network configs
+embed_dim: 32
+n_embed: 1024
+image_key: "segmentation"
+n_labels: 24
+double_z: false
+z_channels: 32
+resolution: 512
+in_channels: 24
+out_ch: 24
+ch: 64
+ch_mult: [1, 1, 2, 2, 4]
+num_res_blocks: 1
+attn_resolutions: [16]
+dropout: 0.0
+
+num_segm_classes: 24
+
+
+# training configs
+val_freq: 5
+print_freq: 100
+weight_decay: 0
+manual_seed: 2021
+num_epochs: 100
+lr: !!float 4.5e-05
+lr_decay: step
+gamma: 0.1
+step: 50
diff --git a/Text2Human/configs/sample_from_parsing.yml b/Text2Human/configs/sample_from_parsing.yml
new file mode 100644
index 0000000000000000000000000000000000000000..47da333846cd76a885f18c41fd775a99e3fce726
--- /dev/null
+++ b/Text2Human/configs/sample_from_parsing.yml
@@ -0,0 +1,93 @@
+name: sample_from_parsing
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 4
+num_workers: 4
+test_img_dir: ./datasets/test_images
+segm_dir: ./datasets/segm
+pose_dir: ./datasets/densepose
+test_ann_file: ./datasets/texture_ann/test
+downsample_factor: 2
+
+model_type: SampleFromParsingModel
+# network configs
+embed_dim: 256
+n_embed: 1024
+codebook_spatial_size: 2
+
+# bottom level vqvae
+bot_n_embed: 512
+bot_codebook_spatial_size: 2
+bot_double_z: false
+bot_z_channels: 256
+bot_resolution: 512
+bot_in_channels: 3
+bot_out_ch: 3
+bot_ch: 128
+bot_ch_mult: [1, 1, 2, 4]
+bot_num_res_blocks: 2
+bot_attn_resolutions: [64]
+bot_dropout: 0.0
+bot_vae_path: ./pretrained_models/vqvae_bottom.pth
+
+# top level vqgan
+top_double_z: false
+top_z_channels: 256
+top_resolution: 512
+top_in_channels: 3
+top_out_ch: 3
+top_ch: 128
+top_ch_mult: [1, 1, 2, 2, 4]
+top_num_res_blocks: 2
+top_attn_resolutions: [32]
+top_dropout: 0.0
+top_vae_path: ./pretrained_models/vqvae_top.pth
+
+# unet configs
+index_pred_encoder_in_channels: 256
+index_pred_fc_in_channels: 64
+index_pred_fc_in_index: 4
+index_pred_fc_channels: 64
+index_pred_fc_num_convs: 1
+index_pred_fc_concat_input: False
+index_pred_fc_dropout_ratio: 0.1
+index_pred_fc_num_classes: 512
+index_pred_fc_align_corners: False
+pretrained_index_network: ./pretrained_models/index_pred_net.pth
+
+# segmentation tokenization
+segm_double_z: false
+segm_z_channels: 32
+segm_resolution: 512
+segm_in_channels: 24
+segm_out_ch: 24
+segm_ch: 64
+segm_ch_mult: [1, 1, 2, 2, 4]
+segm_num_res_blocks: 1
+segm_attn_resolutions: [16]
+segm_dropout: 0.0
+segm_num_segm_classes: 24
+segm_n_embed: 1024
+segm_embed_dim: 32
+segm_token_path: ./pretrained_models/parsing_token.pth
+
+# sampler configs
+codebook_size: 18432
+segm_codebook_size: 1024
+texture_codebook_size: 18
+bert_n_emb: 512
+bert_n_layers: 24
+bert_n_head: 8
+block_size: 512 # 32 x 16
+latent_shape: [32, 16]
+embd_pdrop: 0.0
+resid_pdrop: 0.0
+attn_pdrop: 0.0
+num_head: 18
+pretrained_sampler: ./pretrained_models/sampler.pth
+
+manual_seed: 2021
+sample_steps: 256
diff --git a/Text2Human/configs/sample_from_pose.yml b/Text2Human/configs/sample_from_pose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ed7fb13e4ce6994c55585235026f2c3aab35bf1a
--- /dev/null
+++ b/Text2Human/configs/sample_from_pose.yml
@@ -0,0 +1,107 @@
+name: sample_from_pose
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 4
+num_workers: 4
+pose_dir: ./datasets/densepose
+texture_ann_file: ./datasets/texture_ann/test
+shape_ann_path: ./datasets/shape_ann/test_ann_file.txt
+downsample_factor: 2
+
+model_type: SampleFromPoseModel
+# network configs
+embed_dim: 256
+n_embed: 1024
+codebook_spatial_size: 2
+
+# bottom level vqgan
+bot_n_embed: 512
+bot_codebook_spatial_size: 2
+bot_double_z: false
+bot_z_channels: 256
+bot_resolution: 512
+bot_in_channels: 3
+bot_out_ch: 3
+bot_ch: 128
+bot_ch_mult: [1, 1, 2, 4]
+bot_num_res_blocks: 2
+bot_attn_resolutions: [64]
+bot_dropout: 0.0
+bot_vae_path: ./pretrained_models/vqvae_bottom.pth
+
+# top level vqgan
+top_double_z: false
+top_z_channels: 256
+top_resolution: 512
+top_in_channels: 3
+top_out_ch: 3
+top_ch: 128
+top_ch_mult: [1, 1, 2, 2, 4]
+top_num_res_blocks: 2
+top_attn_resolutions: [32]
+top_dropout: 0.0
+top_vae_path: ./pretrained_models/vqvae_top.pth
+
+# unet configs
+index_pred_encoder_in_channels: 256
+index_pred_fc_in_channels: 64
+index_pred_fc_in_index: 4
+index_pred_fc_channels: 64
+index_pred_fc_num_convs: 1
+index_pred_fc_concat_input: False
+index_pred_fc_dropout_ratio: 0.1
+index_pred_fc_num_classes: 512
+index_pred_fc_align_corners: False
+pretrained_index_network: ./pretrained_models/index_pred_net.pth
+
+# segmentation tokenization
+segm_double_z: false
+segm_z_channels: 32
+segm_resolution: 512
+segm_in_channels: 24
+segm_out_ch: 24
+segm_ch: 64
+segm_ch_mult: [1, 1, 2, 2, 4]
+segm_num_res_blocks: 1
+segm_attn_resolutions: [16]
+segm_dropout: 0.0
+segm_num_segm_classes: 24
+segm_n_embed: 1024
+segm_embed_dim: 32
+segm_token_path: ./pretrained_models/parsing_token.pth
+
+# sampler configs
+codebook_size: 18432
+segm_codebook_size: 1024
+texture_codebook_size: 18
+bert_n_emb: 512
+bert_n_layers: 24
+bert_n_head: 8
+block_size: 512 # 32 x 16
+latent_shape: [32, 16]
+embd_pdrop: 0.0
+resid_pdrop: 0.0
+attn_pdrop: 0.0
+num_head: 18
+pretrained_sampler: ./pretrained_models/sampler.pth
+
+# shape network configs
+shape_embedder_dim: 8
+shape_embedder_out_dim: 128
+shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
+shape_encoder_in_channels: 1
+shape_fc_in_channels: 64
+shape_fc_in_index: 4
+shape_fc_channels: 64
+shape_fc_num_convs: 1
+shape_fc_concat_input: False
+shape_fc_dropout_ratio: 0.1
+shape_fc_num_classes: 24
+shape_fc_align_corners: False
+pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth
+
+manual_seed: 2021
+sample_steps: 256
diff --git a/Text2Human/configs/sampler.yml b/Text2Human/configs/sampler.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7011d93e408bffa88bdc16a195252a2d08cfd606
--- /dev/null
+++ b/Text2Human/configs/sampler.yml
@@ -0,0 +1,83 @@
+name: sampler
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 4
+num_workers: 1
+train_img_dir: ./datasets/train_images
+test_img_dir: ./datasets/test_images
+segm_dir: ./datasets/segm
+pose_dir: ./datasets/densepose
+train_ann_file: ./datasets/texture_ann/train
+val_ann_file: ./datasets/texture_ann/val
+test_ann_file: ./datasets/texture_ann/test
+downsample_factor: 2
+
+# pretrained models
+img_ae_path: ./pretrained_models/vqvae_top.pth
+segm_ae_path: ./pretrained_models/parsing_token.pth
+
+model_type: TransformerTextureAwareModel
+# network configs
+
+# image autoencoder
+img_embed_dim: 256
+img_n_embed: 1024
+img_double_z: false
+img_z_channels: 256
+img_resolution: 512
+img_in_channels: 3
+img_out_ch: 3
+img_ch: 128
+img_ch_mult: [1, 1, 2, 2, 4]
+img_num_res_blocks: 2
+img_attn_resolutions: [32]
+img_dropout: 0.0
+
+# segmentation tokenization
+segm_double_z: false
+segm_z_channels: 32
+segm_resolution: 512
+segm_in_channels: 24
+segm_out_ch: 24
+segm_ch: 64
+segm_ch_mult: [1, 1, 2, 2, 4]
+segm_num_res_blocks: 1
+segm_attn_resolutions: [16]
+segm_dropout: 0.0
+segm_num_segm_classes: 24
+segm_n_embed: 1024
+segm_embed_dim: 32
+
+# sampler configs
+codebook_size: 18432
+segm_codebook_size: 1024
+texture_codebook_size: 18
+bert_n_emb: 512
+bert_n_layers: 24
+bert_n_head: 8
+block_size: 512 # 32 x 16
+latent_shape: [32, 16]
+embd_pdrop: 0.0
+resid_pdrop: 0.0
+attn_pdrop: 0.0
+num_head: 18
+
+# loss configs
+loss_type: reweighted_elbo
+mask_schedule: random
+
+sample_steps: 256
+
+# training configs
+val_freq: 5
+print_freq: 100
+weight_decay: 0
+manual_seed: 2021
+num_epochs: 100
+lr: !!float 1e-4
+lr_decay: step
+gamma: 1.0
+step: 50
diff --git a/Text2Human/configs/vqvae_bottom.yml b/Text2Human/configs/vqvae_bottom.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e426a1cf56a40c7f09d8655c2685431bdf837a61
--- /dev/null
+++ b/Text2Human/configs/vqvae_bottom.yml
@@ -0,0 +1,72 @@
+name: vqvae_bottom
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 4
+num_workers: 4
+train_img_dir: ./datasets/train_images
+test_img_dir: ./datasets/test_images
+segm_dir: ./datasets/segm
+pose_dir: ./datasets/densepose
+train_ann_file: ./datasets/texture_ann/train
+val_ann_file: ./datasets/texture_ann/val
+test_ann_file: ./datasets/texture_ann/test
+downsample_factor: 2
+
+model_type: HierarchyVQSpatialTextureAwareModel
+# network configs
+embed_dim: 256
+n_embed: 1024
+codebook_spatial_size: 2
+
+# bottom level vqvae
+bot_n_embed: 512
+bot_double_z: false
+bot_z_channels: 256
+bot_resolution: 512
+bot_in_channels: 3
+bot_out_ch: 3
+bot_ch: 128
+bot_ch_mult: [1, 1, 2, 4]
+bot_num_res_blocks: 2
+bot_attn_resolutions: [64]
+bot_dropout: 0.0
+
+# top level vqgan
+top_double_z: false
+top_z_channels: 256
+top_resolution: 512
+top_in_channels: 3
+top_out_ch: 3
+top_ch: 128
+top_ch_mult: [1, 1, 2, 2, 4]
+top_num_res_blocks: 2
+top_attn_resolutions: [32]
+top_dropout: 0.0
+top_vae_path: ./pretrained_models/vqvae_top.pth
+
+fix_decoder: false
+
+disc_layers: 3
+disc_weight_max: 1
+disc_start_step: 1
+n_channels: 3
+ndf: 64
+nf: 128
+perceptual_weight: 1.0
+
+num_segm_classes: 24
+
+# training configs
+val_freq: 5
+print_freq: 100
+weight_decay: 0
+manual_seed: 2021
+num_epochs: 1000
+lr: !!float 1.0e-04
+lr_decay: step
+gamma: 1.0
+step: 50
+
diff --git a/Text2Human/configs/vqvae_top.yml b/Text2Human/configs/vqvae_top.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ef4a5f412770bf69d9644d5ac014e1e642ae30cf
--- /dev/null
+++ b/Text2Human/configs/vqvae_top.yml
@@ -0,0 +1,53 @@
+name: vqvae_top
+use_tb_logger: true
+set_CUDA_VISIBLE_DEVICES: ~
+gpu_ids: [3]
+
+# dataset configs
+batch_size: 4
+num_workers: 4
+train_img_dir: ./datasets/train_images
+test_img_dir: ./datasets/test_images
+segm_dir: ./datasets/segm
+pose_dir: ./datasets/densepose
+train_ann_file: ./datasets/texture_ann/train
+val_ann_file: ./datasets/texture_ann/val
+test_ann_file: ./datasets/texture_ann/test
+downsample_factor: 2
+
+model_type: VQImageSegmTextureModel
+# network configs
+embed_dim: 256
+n_embed: 1024
+double_z: false
+z_channels: 256
+resolution: 512
+in_channels: 3
+out_ch: 3
+ch: 128
+ch_mult: [1, 1, 2, 2, 4]
+num_res_blocks: 2
+attn_resolutions: [32]
+dropout: 0.0
+
+disc_layers: 3
+disc_weight_max: 0
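+# note: with disc_weight_max set to 0 and the very large disc_start_step below,
+# the adversarial loss is effectively disabled for the top-level VQVAE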
+disc_start_step: 3000000000000000000000000001
+n_channels: 3
+ndf: 64
+nf: 128
+perceptual_weight: 1.0
+
+num_segm_classes: 24
+
+
+# training configs
+val_freq: 5
+print_freq: 100
+weight_decay: 0
+manual_seed: 2021
+num_epochs: 1000
+lr: !!float 1.0e-04
+lr_decay: step
+gamma: 1.0
+step: 50
diff --git a/Text2Human/data/__init__.py b/Text2Human/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Text2Human/data/mask_dataset.py b/Text2Human/data/mask_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d711076deab116111c513ca8b3e931934584a8
--- /dev/null
+++ b/Text2Human/data/mask_dataset.py
@@ -0,0 +1,59 @@
+import os
+import os.path
+import random
+
+import numpy as np
+import torch
+import torch.utils.data as data
+from PIL import Image
+
+
+class MaskDataset(data.Dataset):
+
+ def __init__(self, segm_dir, ann_dir, downsample_factor=2, xflip=False):
+
+ self._segm_path = segm_dir
+ self._image_fnames = []
+
+ self.downsample_factor = downsample_factor
+ self.xflip = xflip
+
+ # load attributes
+ assert os.path.exists(f'{ann_dir}/upper_fused.txt')
+ for idx, row in enumerate(
+ open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
+ annotations = row.split()
+ self._image_fnames.append(annotations[0])
+
+ def _open_file(self, path_prefix, fname):
+ return open(os.path.join(path_prefix, fname), 'rb')
+
+ def _load_segm(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ fname = f'{fname[:-4]}_segm.png'
+ with self._open_file(self._segm_path, fname) as f:
+ segm = Image.open(f)
+ if self.downsample_factor != 1:
+ width, height = segm.size
+ width = width // self.downsample_factor
+ height = height // self.downsample_factor
+ segm = segm.resize(
+ size=(width, height), resample=Image.NEAREST)
+ segm = np.array(segm)
+ # segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
+ return segm.astype(np.float32)
+
+ def __getitem__(self, index):
+ segm = self._load_segm(index)
+
+ if self.xflip and random.random() > 0.5:
+ segm = segm[:, ::-1].copy()
+
+ segm = torch.from_numpy(segm).long()
+
+ return_dict = {'segm': segm, 'img_name': self._image_fnames[index]}
+
+ return return_dict
+
+ def __len__(self):
+ return len(self._image_fnames)
diff --git a/Text2Human/data/parsing_generation_segm_attr_dataset.py b/Text2Human/data/parsing_generation_segm_attr_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a9d50c2fe21e0bb327334c64148ff79efd9dcad
--- /dev/null
+++ b/Text2Human/data/parsing_generation_segm_attr_dataset.py
@@ -0,0 +1,80 @@
+import os
+import os.path
+
+import numpy as np
+import torch
+import torch.utils.data as data
+from PIL import Image
+
+
+class ParsingGenerationDeepFashionAttrSegmDataset(data.Dataset):
+
+ def __init__(self, segm_dir, pose_dir, ann_file, downsample_factor=2):
+ self._densepose_path = pose_dir
+ self._segm_path = segm_dir
+ self._image_fnames = []
+ self.attrs = []
+
+ self.downsample_factor = downsample_factor
+
+ # training, ground-truth available
+ assert os.path.exists(ann_file)
+ for row in open(os.path.join(ann_file), 'r'):
+ annotations = row.split()
+ self._image_fnames.append(annotations[0])
+ self.attrs.append([int(i) for i in annotations[1:]])
+
+ def _open_file(self, path_prefix, fname):
+ return open(os.path.join(path_prefix, fname), 'rb')
+
+ def _load_densepose(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ fname = f'{fname[:-4]}_densepose.png'
+ with self._open_file(self._densepose_path, fname) as f:
+ densepose = Image.open(f)
+ if self.downsample_factor != 1:
+ width, height = densepose.size
+ width = width // self.downsample_factor
+ height = height // self.downsample_factor
+ densepose = densepose.resize(
+ size=(width, height), resample=Image.NEAREST)
+            # the densepose PNG is stored in IUV channel order; keep only the
+            # last channel, [1, H, W]
+            densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
+ return densepose.astype(np.float32)
+
+ def _load_segm(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ fname = f'{fname[:-4]}_segm.png'
+ with self._open_file(self._segm_path, fname) as f:
+ segm = Image.open(f)
+ if self.downsample_factor != 1:
+ width, height = segm.size
+ width = width // self.downsample_factor
+ height = height // self.downsample_factor
+ segm = segm.resize(
+ size=(width, height), resample=Image.NEAREST)
+ segm = np.array(segm)
+ return segm.astype(np.float32)
+
+ def __getitem__(self, index):
+ pose = self._load_densepose(index)
+ segm = self._load_segm(index)
+ attr = self.attrs[index]
+
+ pose = torch.from_numpy(pose)
+ segm = torch.LongTensor(segm)
+ attr = torch.LongTensor(attr)
+
+ pose = pose / 12. - 1
+
+ return_dict = {
+ 'densepose': pose,
+ 'segm': segm,
+ 'attr': attr,
+ 'img_name': self._image_fnames[index]
+ }
+
+ return return_dict
+
+ def __len__(self):
+ return len(self._image_fnames)
diff --git a/Text2Human/data/pose_attr_dataset.py b/Text2Human/data/pose_attr_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7245846cb321db91c7935edbae83f7c451110725
--- /dev/null
+++ b/Text2Human/data/pose_attr_dataset.py
@@ -0,0 +1,109 @@
+import os
+import os.path
+import random
+
+import numpy as np
+import torch
+import torch.utils.data as data
+from PIL import Image
+
+
+class DeepFashionAttrPoseDataset(data.Dataset):
+
+ def __init__(self,
+ pose_dir,
+ texture_ann_dir,
+ shape_ann_path,
+ downsample_factor=2,
+ xflip=False):
+ self._densepose_path = pose_dir
+ self._image_fnames_target = []
+ self._image_fnames = []
+ self.upper_fused_attrs = []
+ self.lower_fused_attrs = []
+ self.outer_fused_attrs = []
+ self.shape_attrs = []
+
+ self.downsample_factor = downsample_factor
+ self.xflip = xflip
+
+ # load attributes
+ assert os.path.exists(f'{texture_ann_dir}/upper_fused.txt')
+ for idx, row in enumerate(
+ open(os.path.join(f'{texture_ann_dir}/upper_fused.txt'), 'r')):
+ annotations = row.split()
+ self._image_fnames_target.append(annotations[0])
+ self._image_fnames.append(f'{annotations[0].split(".")[0]}.png')
+ self.upper_fused_attrs.append(int(annotations[1]))
+
+ assert len(self._image_fnames_target) == len(self.upper_fused_attrs)
+
+ assert os.path.exists(f'{texture_ann_dir}/lower_fused.txt')
+ for idx, row in enumerate(
+ open(os.path.join(f'{texture_ann_dir}/lower_fused.txt'), 'r')):
+ annotations = row.split()
+ assert self._image_fnames_target[idx] == annotations[0]
+ self.lower_fused_attrs.append(int(annotations[1]))
+
+ assert len(self._image_fnames_target) == len(self.lower_fused_attrs)
+
+ assert os.path.exists(f'{texture_ann_dir}/outer_fused.txt')
+ for idx, row in enumerate(
+ open(os.path.join(f'{texture_ann_dir}/outer_fused.txt'), 'r')):
+ annotations = row.split()
+ assert self._image_fnames_target[idx] == annotations[0]
+ self.outer_fused_attrs.append(int(annotations[1]))
+
+ assert len(self._image_fnames_target) == len(self.outer_fused_attrs)
+
+ assert os.path.exists(shape_ann_path)
+ for idx, row in enumerate(open(os.path.join(shape_ann_path), 'r')):
+ annotations = row.split()
+ assert self._image_fnames_target[idx] == annotations[0]
+ self.shape_attrs.append([int(i) for i in annotations[1:]])
+
+ def _open_file(self, path_prefix, fname):
+ return open(os.path.join(path_prefix, fname), 'rb')
+
+ def _load_densepose(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ fname = f'{fname[:-4]}_densepose.png'
+ with self._open_file(self._densepose_path, fname) as f:
+ densepose = Image.open(f)
+ if self.downsample_factor != 1:
+ width, height = densepose.size
+ width = width // self.downsample_factor
+ height = height // self.downsample_factor
+ densepose = densepose.resize(
+ size=(width, height), resample=Image.NEAREST)
+            # the densepose PNG is stored in IUV channel order; keep only the
+            # last channel, [1, H, W]
+            densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
+ return densepose.astype(np.float32)
+
+ def __getitem__(self, index):
+ pose = self._load_densepose(index)
+ shape_attr = self.shape_attrs[index]
+ shape_attr = torch.LongTensor(shape_attr)
+
+ if self.xflip and random.random() > 0.5:
+ pose = pose[:, :, ::-1].copy()
+
+ upper_fused_attr = self.upper_fused_attrs[index]
+ lower_fused_attr = self.lower_fused_attrs[index]
+ outer_fused_attr = self.outer_fused_attrs[index]
+
+ pose = pose / 12. - 1
+
+ return_dict = {
+ 'densepose': pose,
+ 'img_name': self._image_fnames_target[index],
+ 'shape_attr': shape_attr,
+ 'upper_fused_attr': upper_fused_attr,
+ 'lower_fused_attr': lower_fused_attr,
+ 'outer_fused_attr': outer_fused_attr,
+ }
+
+ return return_dict
+
+ def __len__(self):
+ return len(self._image_fnames)
diff --git a/Text2Human/data/segm_attr_dataset.py b/Text2Human/data/segm_attr_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ab45cb71bce2f20e703f8293a7f2b430c1aaa4e
--- /dev/null
+++ b/Text2Human/data/segm_attr_dataset.py
@@ -0,0 +1,167 @@
+import os
+import os.path
+import random
+
+import numpy as np
+import torch
+import torch.utils.data as data
+from PIL import Image
+
+
+class DeepFashionAttrSegmDataset(data.Dataset):
+
+ def __init__(self,
+ img_dir,
+ segm_dir,
+ pose_dir,
+ ann_dir,
+ downsample_factor=2,
+ xflip=False):
+ self._img_path = img_dir
+ self._densepose_path = pose_dir
+ self._segm_path = segm_dir
+ self._image_fnames = []
+ self.upper_fused_attrs = []
+ self.lower_fused_attrs = []
+ self.outer_fused_attrs = []
+
+ self.downsample_factor = downsample_factor
+ self.xflip = xflip
+
+ # load attributes
+ assert os.path.exists(f'{ann_dir}/upper_fused.txt')
+ for idx, row in enumerate(
+ open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
+ annotations = row.split()
+ self._image_fnames.append(annotations[0])
+ # assert self._image_fnames[idx] == annotations[0]
+ self.upper_fused_attrs.append(int(annotations[1]))
+
+ assert len(self._image_fnames) == len(self.upper_fused_attrs)
+
+ assert os.path.exists(f'{ann_dir}/lower_fused.txt')
+ for idx, row in enumerate(
+ open(os.path.join(f'{ann_dir}/lower_fused.txt'), 'r')):
+ annotations = row.split()
+ assert self._image_fnames[idx] == annotations[0]
+ self.lower_fused_attrs.append(int(annotations[1]))
+
+ assert len(self._image_fnames) == len(self.lower_fused_attrs)
+
+ assert os.path.exists(f'{ann_dir}/outer_fused.txt')
+ for idx, row in enumerate(
+ open(os.path.join(f'{ann_dir}/outer_fused.txt'), 'r')):
+ annotations = row.split()
+ assert self._image_fnames[idx] == annotations[0]
+ self.outer_fused_attrs.append(int(annotations[1]))
+
+ assert len(self._image_fnames) == len(self.outer_fused_attrs)
+
+ # remove the overlapping item between upper cls and lower cls
+ # cls 21 can appear with upper clothes
+ # cls 4 can appear with lower clothes
+ self.upper_cls = [1., 4.]
+ self.lower_cls = [3., 5., 21.]
+ self.outer_cls = [2.]
+ self.other_cls = [
+ 11., 18., 7., 8., 9., 10., 12., 16., 17., 19., 20., 22., 23., 15.,
+ 14., 13., 0., 6.
+ ]
+
+ def _open_file(self, path_prefix, fname):
+ return open(os.path.join(path_prefix, fname), 'rb')
+
+ def _load_raw_image(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ with self._open_file(self._img_path, fname) as f:
+ image = Image.open(f)
+ if self.downsample_factor != 1:
+ width, height = image.size
+ width = width // self.downsample_factor
+ height = height // self.downsample_factor
+ image = image.resize(
+ size=(width, height), resample=Image.LANCZOS)
+ image = np.array(image)
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis] # HW => HWC
+ image = image.transpose(2, 0, 1) # HWC => CHW
+ return image
+
+ def _load_densepose(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ fname = f'{fname[:-4]}_densepose.png'
+ with self._open_file(self._densepose_path, fname) as f:
+ densepose = Image.open(f)
+ if self.downsample_factor != 1:
+ width, height = densepose.size
+ width = width // self.downsample_factor
+ height = height // self.downsample_factor
+ densepose = densepose.resize(
+ size=(width, height), resample=Image.NEAREST)
+            # the densepose PNG is stored in IUV channel order; keep only the
+            # last channel, [1, H, W]
+            densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
+ return densepose.astype(np.float32)
+
+ def _load_segm(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ fname = f'{fname[:-4]}_segm.png'
+ with self._open_file(self._segm_path, fname) as f:
+ segm = Image.open(f)
+ if self.downsample_factor != 1:
+ width, height = segm.size
+ width = width // self.downsample_factor
+ height = height // self.downsample_factor
+ segm = segm.resize(
+ size=(width, height), resample=Image.NEAREST)
+ segm = np.array(segm)
+ segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
+ return segm.astype(np.float32)
+
+ def __getitem__(self, index):
+ image = self._load_raw_image(index)
+ pose = self._load_densepose(index)
+ segm = self._load_segm(index)
+
+ if self.xflip and random.random() > 0.5:
+ assert image.ndim == 3 # CHW
+ image = image[:, :, ::-1].copy()
+ pose = pose[:, :, ::-1].copy()
+ segm = segm[:, :, ::-1].copy()
+
+ image = torch.from_numpy(image)
+ segm = torch.from_numpy(segm)
+
+ upper_fused_attr = self.upper_fused_attrs[index]
+ lower_fused_attr = self.lower_fused_attrs[index]
+ outer_fused_attr = self.outer_fused_attrs[index]
+
+ # mask 0: denotes the common codebook,
+ # mask (attr + 1): denotes the texture-specific codebook
+ mask = torch.zeros_like(segm)
+ if upper_fused_attr != 17:
+ for cls in self.upper_cls:
+ mask[segm == cls] = upper_fused_attr + 1
+
+ if lower_fused_attr != 17:
+ for cls in self.lower_cls:
+ mask[segm == cls] = lower_fused_attr + 1
+
+ if outer_fused_attr != 17:
+ for cls in self.outer_cls:
+ mask[segm == cls] = outer_fused_attr + 1
+
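+        # map densepose values (assumed range [0, 24]) and image pixels
+        # ([0, 255]) to [-1, 1]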
+ pose = pose / 12. - 1
+ image = image / 127.5 - 1
+
+ return_dict = {
+ 'image': image,
+ 'densepose': pose,
+ 'segm': segm,
+ 'texture_mask': mask,
+ 'img_name': self._image_fnames[index]
+ }
+
+ return return_dict
+
+ def __len__(self):
+ return len(self._image_fnames)
diff --git a/Text2Human/environment/text2human_env.yaml b/Text2Human/environment/text2human_env.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..542bacaa210e06ffe01e7ef5e07a381673c97a03
--- /dev/null
+++ b/Text2Human/environment/text2human_env.yaml
@@ -0,0 +1,114 @@
+name: text2human
+channels:
+ - pytorch
+ - anaconda
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - astroid=2.5=py36h06a4308_1
+ - blas=1.0=mkl
+ - brotlipy=0.7.0=py36h7b6447c_1000
+ - ca-certificates=2021.10.26=h06a4308_2
+ - certifi=2021.5.30=py36h06a4308_0
+ - cffi=1.14.3=py36he30daa8_0
+ - chardet=3.0.4=py36_1003
+ - click=8.0.3=pyhd3eb1b0_0
+ - cryptography=3.1.1=py36h1ba5d50_0
+ - cudatoolkit=10.1.243=h6bb024c_0
+ - dataclasses=0.8=pyh4f3eec9_6
+ - dbus=1.13.18=hb2f20db_0
+ - expat=2.2.10=he6710b0_2
+ - filelock=3.4.0=pyhd3eb1b0_0
+ - fontconfig=2.13.0=h9420a91_0
+ - freetype=2.10.4=h5ab3b9f_0
+ - glib=2.56.2=hd408876_0
+ - gst-plugins-base=1.14.0=hbbd80ab_1
+ - gstreamer=1.14.0=hb453b48_1
+ - icu=58.2=he6710b0_3
+ - idna=2.10=py_0
+ - importlib-metadata=4.8.1=py36h06a4308_0
+ - importlib_metadata=4.8.1=hd3eb1b0_0
+ - intel-openmp=2020.2=254
+ - isort=5.7.0=pyhd3eb1b0_0
+ - joblib=1.0.1=pyhd3eb1b0_0
+ - jpeg=9b=habf39ab_1
+ - lazy-object-proxy=1.5.2=py36h27cfd23_0
+ - lcms2=2.11=h396b838_0
+ - ld_impl_linux-64=2.33.1=h53a641e_7
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=9.1.0=hdf63c60_0
+ - libpng=1.6.37=hbc83047_0
+ - libprotobuf=3.17.2=h4ff587b_1
+ - libstdcxx-ng=9.1.0=hdf63c60_0
+ - libtiff=4.2.0=h3942068_0
+ - libuuid=1.0.3=h1bed415_2
+ - libuv=1.40.0=h7b6447c_0
+ - libwebp-base=1.2.0=h27cfd23_0
+ - libxcb=1.14=h7b6447c_0
+ - libxml2=2.9.10=hb55368b_3
+ - lz4-c=1.9.3=h2531618_0
+ - mccabe=0.6.1=py36_1
+ - mkl=2020.2=256
+ - mkl-service=2.3.0=py36he8ac12f_0
+ - mkl_fft=1.3.0=py36h54f3939_0
+ - mkl_random=1.1.1=py36h0573a6f_0
+ - ncurses=6.2=he6710b0_1
+ - ninja=1.10.2=h5e70eb0_2
+ - numpy=1.19.2=py36h54aff64_0
+ - numpy-base=1.19.2=py36hfa32c7d_0
+ - olefile=0.46=py36_0
+ - openssl=1.1.1m=h7f8727e_0
+ - packaging=21.3=pyhd3eb1b0_0
+ - pcre=8.44=he6710b0_0
+ - pillow=8.1.2=py36he98fc37_0
+ - pip=21.0.1=py36h06a4308_0
+ - protobuf=3.17.2=py36h295c915_0
+ - pycparser=2.20=py_2
+ - pylint=2.7.2=py36h06a4308_1
+ - pyopenssl=19.1.0=py_1
+ - pyqt=5.9.2=py36h05f1152_2
+ - pysocks=1.7.1=py36_0
+ - python=3.6.13=hdb3f193_0
+ - pytorch=1.7.1=py3.6_cuda10.1.243_cudnn7.6.3_0
+ - qt=5.9.7=h5867ecd_1
+ - readline=8.1=h27cfd23_0
+ - regex=2021.8.3=py36h7f8727e_0
+ - requests=2.24.0=py_0
+ - setuptools=52.0.0=py36h06a4308_0
+ - sip=4.19.8=py36hf484d3e_0
+ - six=1.15.0=py36h06a4308_0
+ - sqlite=3.35.2=hdfb4753_0
+ - tk=8.6.10=hbc83047_0
+ - toml=0.10.2=pyhd3eb1b0_0
+ - torchvision=0.8.2=py36_cu101
+ - tqdm=4.62.3=pyhd3eb1b0_1
+ - typed-ast=1.4.2=py36h27cfd23_1
+ - typing-extensions=3.10.0.2=hd3eb1b0_0
+ - typing_extensions=3.10.0.2=pyh06a4308_0
+ - urllib3=1.25.11=py_0
+ - wheel=0.36.2=pyhd3eb1b0_0
+ - wrapt=1.12.1=py36h7b6447c_1
+ - xz=5.2.5=h7b6447c_0
+ - yaml=0.2.5=h7b6447c_0
+ - zipp=3.6.0=pyhd3eb1b0_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.4.5=h9ceee32_0
+ - pip:
+ - addict==2.4.0
+ - cycler==0.11.0
+ - einops==0.4.0
+ - kiwisolver==1.3.1
+ - matplotlib==3.3.4
+ - mmcv-full==1.2.1
+ - mmsegmentation==0.9.0
+ - nltk==3.6.7
+ - opencv-python==4.5.5.62
+ - pyparsing==3.0.7
+ - python-dateutil==2.8.2
+ - pyyaml==6.0
+ - scikit-learn==0.24.2
+ - scipy==1.5.4
+ - sentencepiece==0.1.96
+ - terminaltables==3.1.10
+ - threadpoolctl==3.0.0
+ - yapf==0.32.0
diff --git a/Text2Human/models/__init__.py b/Text2Human/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..caeb363ed8ade72ac2bd3214fcbba62313efc262
--- /dev/null
+++ b/Text2Human/models/__init__.py
@@ -0,0 +1,42 @@
+import glob
+import importlib
+import logging
+import os.path as osp
+
+# automatically scan and import model modules
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [
+ osp.splitext(osp.basename(v))[0]
+ for v in glob.glob(f'{model_folder}/*_model.py')
+]
+# import all the model modules
+_model_modules = [
+ importlib.import_module(f'models.{file_name}')
+ for file_name in model_filenames
+]
+
+
+def create_model(opt):
+ """Create model.
+
+ Args:
+        opt (dict): Configuration. It contains:
+ model_type (str): Model type.
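+
+    Example:
+        A sketch (real configs provide many more keys than ``model_type``;
+        see the files under ``configs/``)::
+
+            import yaml
+            with open('./configs/parsing_gen.yml') as f:
+                opt = yaml.safe_load(f)
+            model = create_model(opt)  # -> ParsingGenModel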
+ """
+ model_type = opt['model_type']
+
+    # dynamically find and instantiate the requested model class
+ for module in _model_modules:
+ model_cls = getattr(module, model_type, None)
+ if model_cls is not None:
+ break
+ if model_cls is None:
+ raise ValueError(f'Model {model_type} is not found.')
+
+ model = model_cls(opt)
+
+ logger = logging.getLogger('base')
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
+ return model
diff --git a/Text2Human/models/archs/__init__.py b/Text2Human/models/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Text2Human/models/archs/fcn_arch.py b/Text2Human/models/archs/fcn_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8bb7c1b9fc66379e5a32ac02a24de63fe6953e7
--- /dev/null
+++ b/Text2Human/models/archs/fcn_arch.py
@@ -0,0 +1,418 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, normal_init
+from mmseg.ops import resize
+
+
+class BaseDecodeHead(nn.Module):
+ """Base class for BaseDecodeHead.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+            'resize_concat': Multiple feature maps will be resized to the
+                same size as the first one and then concatenated together.
+                Usually used in the FCN head of HRNet.
+            'multiple_select': Multiple feature maps will be bundled into
+                a list and passed into the decode head.
+            None: Only one selected feature map is allowed.
+ Default: None.
+        ignore_index (int | None): The label index to be ignored. When using
+            masked BCE loss, ignore_index should be set to None. Default: 255
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ ignore_index=255,
+ align_corners=False):
+ super(BaseDecodeHead, self).__init__()
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+        Specifically, when input_transform is None, only a single feature map
+        will be selected. So in_channels and in_index must be of type int.
+        When input_transform is not None, in_channels and in_index must be
+        list or tuple, with the same length.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resized to the
+                    same size as the first one and then concatenated together.
+                    Usually used in the FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundled into
+                    a list and passed into the decode head.
+                None: Only one selected feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+
+
+class FCNHead(BaseDecodeHead):
+ """Fully Convolution Networks for Semantic Segmentation.
+
+    This head is an implementation of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
+
+ Args:
+ num_convs (int): Number of convs in the head. Default: 2.
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
+ concat_input (bool): Whether concat the input and output of convs
+ before classification layer.
+ """
+
+ def __init__(self,
+ num_convs=2,
+ kernel_size=3,
+ concat_input=True,
+ **kwargs):
+ assert num_convs >= 0
+ self.num_convs = num_convs
+ self.concat_input = concat_input
+ self.kernel_size = kernel_size
+ super(FCNHead, self).__init__(**kwargs)
+ if num_convs == 0:
+ assert self.in_channels == self.channels
+
+ convs = []
+ convs.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ for i in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if num_convs == 0:
+ self.convs = nn.Identity()
+ else:
+ self.convs = nn.Sequential(*convs)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs(x)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
+
+
+class MultiHeadFCNHead(nn.Module):
+    """Multi-head fully convolutional network head for semantic segmentation.
+
+    A variant of `FCNNet <https://arxiv.org/abs/1411.4038>`_ with multiple
+    parallel prediction heads.
+
+    Args:
+        num_convs (int): Number of convs in the head. Default: 2.
+        kernel_size (int): The kernel size for convs in the head. Default: 3.
+        concat_input (bool): Whether concat the input and output of convs
+            before classification layer.
+        num_head (int): Number of parallel prediction heads. Default: 18.
+    """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ ignore_index=255,
+ align_corners=False,
+ num_convs=2,
+ kernel_size=3,
+ concat_input=True,
+ num_head=18,
+ **kwargs):
+ super(MultiHeadFCNHead, self).__init__()
+ assert num_convs >= 0
+ self.num_convs = num_convs
+ self.concat_input = concat_input
+ self.kernel_size = kernel_size
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.num_head = num_head
+
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+
+        if dropout_ratio > 0:
+            self.dropout = nn.Dropout2d(dropout_ratio)
+        else:
+            self.dropout = None
+
+ conv_seg_head_list = []
+ for _ in range(self.num_head):
+ conv_seg_head_list.append(
+ nn.Conv2d(channels, num_classes, kernel_size=1))
+
+ self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)
+
+ self.init_weights()
+
+ if num_convs == 0:
+ assert self.in_channels == self.channels
+
+ convs_list = []
+ conv_cat_list = []
+
+ for _ in range(self.num_head):
+ convs = []
+ convs.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ for _ in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if num_convs == 0:
+ convs_list.append(nn.Identity())
+ else:
+ convs_list.append(nn.Sequential(*convs))
+ if self.concat_input:
+ conv_cat_list.append(
+ ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ self.convs_list = nn.ModuleList(convs_list)
+ self.conv_cat_list = nn.ModuleList(conv_cat_list)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+
+ output_list = []
+ for head_idx in range(self.num_head):
+ output = self.convs_list[head_idx](x)
+ if self.concat_input:
+ output = self.conv_cat_list[head_idx](
+ torch.cat([x, output], dim=1))
+ if self.dropout is not None:
+ output = self.dropout(output)
+ output = self.conv_seg_head_list[head_idx](output)
+ output_list.append(output)
+
+ return output_list
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+        Specifically, when input_transform is None, only a single feature map
+        will be selected. So in_channels and in_index must be of type int.
+        When input_transform is not None, in_channels and in_index must be
+        list or tuple, with the same length.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resized to the
+                    same size as the first one and then concatenated together.
+                    Usually used in the FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundled into
+                    a list and passed into the decode head.
+                None: Only one selected feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ for conv_seg_head in self.conv_seg_head_list:
+ normal_init(conv_seg_head, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
diff --git a/Text2Human/models/archs/shape_attr_embedding_arch.py b/Text2Human/models/archs/shape_attr_embedding_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..217c179be3591173596bac7eb1df277e6b1a3c23
--- /dev/null
+++ b/Text2Human/models/archs/shape_attr_embedding_arch.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class ShapeAttrEmbedding(nn.Module):
+
+ def __init__(self, dim, out_dim, cls_num_list):
+ super(ShapeAttrEmbedding, self).__init__()
+
+ for idx, cls_num in enumerate(cls_num_list):
+ setattr(
+ self, f'attr_{idx}',
+ nn.Sequential(
+ nn.Linear(cls_num, dim), nn.LeakyReLU(),
+ nn.Linear(dim, dim)))
+ self.cls_num_list = cls_num_list
+ self.attr_num = len(cls_num_list)
+ self.fusion = nn.Sequential(
+ nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
+ nn.Linear(out_dim, out_dim))
+
+ def forward(self, attr):
+ attr_embedding_list = []
+ for idx in range(self.attr_num):
+ attr_embed_fc = getattr(self, f'attr_{idx}')
+ attr_embedding_list.append(
+ attr_embed_fc(
+ F.one_hot(
+ attr[:, idx],
+ num_classes=self.cls_num_list[idx]).to(torch.float32)))
+ attr_embedding = torch.cat(attr_embedding_list, dim=1)
+ attr_embedding = self.fusion(attr_embedding)
+
+ return attr_embedding
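+
+
+# A hedged usage sketch (the sizes below are illustrative and not taken from
+# any config in this repository):
+#   embed = ShapeAttrEmbedding(dim=32, out_dim=128, cls_num_list=[4, 3, 5])
+#   attr = torch.tensor([[1, 0, 4], [3, 2, 1]])  # (B, num_attrs) integer labels
+#   out = embed(attr)                            # -> (B, 128)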
diff --git a/Text2Human/models/archs/transformer_arch.py b/Text2Human/models/archs/transformer_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..8027555b00c3b6b6cc50ef68081fa02df47cf7b0
--- /dev/null
+++ b/Text2Human/models/archs/transformer_arch.py
@@ -0,0 +1,273 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CausalSelfAttention(nn.Module):
+ """
+ A vanilla multi-head masked self-attention layer with a projection at the end.
+    It is possible to use torch.nn.MultiheadAttention here, but an explicit
+    implementation is included to show that there is nothing too scary going on.
+ """
+
+ def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
+ latent_shape, sampler):
+ super().__init__()
+ assert bert_n_emb % bert_n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(bert_n_emb, bert_n_emb)
+ self.query = nn.Linear(bert_n_emb, bert_n_emb)
+ self.value = nn.Linear(bert_n_emb, bert_n_emb)
+ # regularization
+ self.attn_drop = nn.Dropout(attn_pdrop)
+ self.resid_drop = nn.Dropout(resid_pdrop)
+ # output projection
+ self.proj = nn.Linear(bert_n_emb, bert_n_emb)
+ self.n_head = bert_n_head
+ self.causal = True if sampler == 'autoregressive' else False
+ if self.causal:
+ block_size = np.prod(latent_shape)
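+            # lower-triangular mask: position i may attend only to positions j <= i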
+ mask = torch.tril(torch.ones(block_size, block_size))
+ self.register_buffer("mask", mask.view(1, 1, block_size,
+ block_size))
+
+ def forward(self, x, layer_past=None):
+ B, T, C = x.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = self.key(x).view(B, T, self.n_head,
+ C // self.n_head).transpose(1,
+ 2) # (B, nh, T, hs)
+ q = self.query(x).view(B, T, self.n_head,
+ C // self.n_head).transpose(1,
+ 2) # (B, nh, T, hs)
+ v = self.value(x).view(B, T, self.n_head,
+ C // self.n_head).transpose(1,
+ 2) # (B, nh, T, hs)
+
+ present = torch.stack((k, v))
+ if self.causal and layer_past is not None:
+ past_key, past_value = layer_past
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+
+ if self.causal and layer_past is None:
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
+
+ att = F.softmax(att, dim=-1)
+ att = self.attn_drop(att)
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ # re-assemble all head outputs side by side
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
+
+ # output projection
+ y = self.resid_drop(self.proj(y))
+ return y, present
+
+
+class Block(nn.Module):
+ """ an unassuming Transformer block """
+
+ def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
+ latent_shape, sampler):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(bert_n_emb)
+ self.ln2 = nn.LayerNorm(bert_n_emb)
+ self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
+ resid_pdrop, latent_shape, sampler)
+ self.mlp = nn.Sequential(
+ nn.Linear(bert_n_emb, 4 * bert_n_emb),
+ nn.GELU(), # nice
+ nn.Linear(4 * bert_n_emb, bert_n_emb),
+ nn.Dropout(resid_pdrop),
+ )
+
+ def forward(self, x, layer_past=None, return_present=False):
+
+ attn, present = self.attn(self.ln1(x), layer_past)
+ x = x + attn
+ x = x + self.mlp(self.ln2(x))
+
+ if layer_past is not None or return_present:
+ return x, present
+ return x
+
+
+class Transformer(nn.Module):
+ """ the full GPT language model, with a context size of block_size """
+
+ def __init__(self,
+ codebook_size,
+ segm_codebook_size,
+ bert_n_emb,
+ bert_n_layers,
+ bert_n_head,
+ block_size,
+ latent_shape,
+ embd_pdrop,
+ resid_pdrop,
+ attn_pdrop,
+ sampler='absorbing'):
+ super().__init__()
+
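+        # one extra vocabulary entry is reserved for the mask token used by the
+        # non-autoregressive ('absorbing') sampler; the autoregressive branch
+        # below removes it again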
+ self.vocab_size = codebook_size + 1
+ self.n_embd = bert_n_emb
+ self.block_size = block_size
+ self.n_layers = bert_n_layers
+ self.codebook_size = codebook_size
+ self.segm_codebook_size = segm_codebook_size
+ self.causal = sampler == 'autoregressive'
+ if self.causal:
+ self.vocab_size = codebook_size
+
+ self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
+ self.pos_emb = nn.Parameter(
+ torch.zeros(1, self.block_size, self.n_embd))
+ self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
+ self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
+ self.drop = nn.Dropout(embd_pdrop)
+
+ # transformer
+ self.blocks = nn.Sequential(*[
+ Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
+ latent_shape, sampler) for _ in range(self.n_layers)
+ ])
+ # decoder head
+ self.ln_f = nn.LayerNorm(self.n_embd)
+ self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False)
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, segm_tokens, t=None):
+ # each index maps to a (learnable) vector
+ token_embeddings = self.tok_emb(idx)
+
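+        # the segmentation tokens act as per-position conditioning: their
+        # embeddings are simply added to the token and position embeddings below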
+ segm_embeddings = self.segm_emb(segm_tokens)
+
+ if self.causal:
+ token_embeddings = torch.cat((self.start_tok.repeat(
+ token_embeddings.size(0), 1, 1), token_embeddings),
+ dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ # each position maps to a (learnable) vector
+
+ position_embeddings = self.pos_emb[:, :t, :]
+
+ x = token_embeddings + position_embeddings + segm_embeddings
+ x = self.drop(x)
+ for block in self.blocks:
+ x = block(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ return logits
+
+
+class TransformerMultiHead(nn.Module):
+ """ the full GPT language model, with a context size of block_size """
+
+ def __init__(self,
+ codebook_size,
+ segm_codebook_size,
+ texture_codebook_size,
+ bert_n_emb,
+ bert_n_layers,
+ bert_n_head,
+ block_size,
+ latent_shape,
+ embd_pdrop,
+ resid_pdrop,
+ attn_pdrop,
+ num_head,
+ sampler='absorbing'):
+ super().__init__()
+
+ self.vocab_size = codebook_size + 1
+ self.n_embd = bert_n_emb
+ self.block_size = block_size
+ self.n_layers = bert_n_layers
+ self.codebook_size = codebook_size
+ self.segm_codebook_size = segm_codebook_size
+ self.texture_codebook_size = texture_codebook_size
+ self.causal = sampler == 'autoregressive'
+ if self.causal:
+ self.vocab_size = codebook_size
+
+ self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
+ self.pos_emb = nn.Parameter(
+ torch.zeros(1, self.block_size, self.n_embd))
+ self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
+ self.texture_emb = nn.Embedding(self.texture_codebook_size,
+ self.n_embd)
+ self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
+ self.drop = nn.Dropout(embd_pdrop)
+
+ # transformer
+ self.blocks = nn.Sequential(*[
+ Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
+ latent_shape, sampler) for _ in range(self.n_layers)
+ ])
+ # decoder head
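+        # the codebook indices are split evenly across num_head output heads;
+        # each head predicts logits over codebook_size // num_head classes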
+ self.num_head = num_head
+ self.head_class_num = codebook_size // self.num_head
+ self.ln_f = nn.LayerNorm(self.n_embd)
+ self.head_list = nn.ModuleList([
+ nn.Linear(self.n_embd, self.head_class_num, bias=False)
+ for _ in range(self.num_head)
+ ])
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, segm_tokens, texture_tokens, t=None):
+ # each index maps to a (learnable) vector
+ token_embeddings = self.tok_emb(idx)
+ segm_embeddings = self.segm_emb(segm_tokens)
+ texture_embeddings = self.texture_emb(texture_tokens)
+
+ if self.causal:
+ token_embeddings = torch.cat((self.start_tok.repeat(
+ token_embeddings.size(0), 1, 1), token_embeddings),
+ dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ # each position maps to a (learnable) vector
+
+ position_embeddings = self.pos_emb[:, :t, :]
+
+ x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings
+ x = self.drop(x)
+ for block in self.blocks:
+ x = block(x)
+ x = self.ln_f(x)
+ logits_list = [self.head_list[i](x) for i in range(self.num_head)]
+
+ return logits_list
diff --git a/Text2Human/models/archs/unet_arch.py b/Text2Human/models/archs/unet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b110d6938a0a1565e07518bb98a04eb608fc3f14
--- /dev/null
+++ b/Text2Human/models/archs/unet_arch.py
@@ -0,0 +1,693 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+ build_norm_layer, build_upsample_layer, constant_init,
+ kaiming_init)
+from mmcv.runner import load_checkpoint
+from mmcv.utils.parrots_wrapper import _BatchNorm
+from mmseg.utils import get_root_logger
+
+
+class UpConvBlock(nn.Module):
+ """Upsample convolution block in decoder for UNet.
+
+ This upsample convolution block consists of one upsample module
+ followed by one convolution block. The upsample module expands the
+ high-level low-resolution feature map and the convolution block fuses
+ the upsampled high-level low-resolution feature map and the low-level
+ high-resolution feature map from encoder.
+
+ Args:
+ conv_block (nn.Sequential): Sequential of convolutional layers.
+        in_channels (int): Number of input channels of the high-level
+            low-resolution feature map to be upsampled.
+        skip_channels (int): Number of input channels of the low-level
+            high-resolution feature map from encoder.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers in the conv_block.
+ Default: 2.
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
+ dilation (int): Dilation rate of convolutional layer in conv_block.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv'). If the size of the
+            high-level feature map is the same as that of the skip feature map
+            (low-level feature map from encoder), the high-level feature map
+            does not need to be upsampled and upsample_cfg should be None.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ conv_block,
+ in_channels,
+ skip_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ dcn=None,
+ plugins=None):
+ super(UpConvBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.conv_block = conv_block(
+ in_channels=2 * skip_channels,
+ out_channels=out_channels,
+ num_convs=num_convs,
+ stride=stride,
+ dilation=dilation,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None)
+ if upsample_cfg is not None:
+ self.upsample = build_upsample_layer(
+ cfg=upsample_cfg,
+ in_channels=in_channels,
+ out_channels=skip_channels,
+ with_cp=with_cp,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.upsample = ConvModule(
+ in_channels,
+ skip_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, skip, x):
+ """Forward function."""
+
+ x = self.upsample(x)
+ out = torch.cat([skip, x], dim=1)
+ out = self.conv_block(out)
+
+ return out
+
+
+class BasicConvBlock(nn.Module):
+ """Basic convolutional block for UNet.
+
+ This module consists of several plain convolutional layers.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers. Default: 2.
+        stride (int): Whether to use stride convolution to downsample the
+            input feature map. If stride=2, only the first convolutional
+            layer uses stride convolution to downsample the input feature
+            map. Options are 1 or 2. Default: 1.
+        dilation (int): Whether to use dilated convolution to expand the
+            receptive field. Sets the dilation rate of each convolutional
+            layer except the first one, whose dilation rate is always 1.
+            Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ dcn=None,
+ plugins=None):
+ super(BasicConvBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.with_cp = with_cp
+ convs = []
+ for i in range(num_convs):
+ convs.append(
+ ConvModule(
+ in_channels=in_channels if i == 0 else out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride if i == 0 else 1,
+ dilation=1 if i == 0 else dilation,
+ padding=1 if i == 0 else dilation,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ self.convs = nn.Sequential(*convs)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.convs, x)
+ else:
+ out = self.convs(x)
+ return out
+
+
+class DeconvModule(nn.Module):
+ """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+ This module uses deconvolution to upsample feature map in the decoder
+ of UNet.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ kernel_size (int): Kernel size of the convolutional layer. Default: 4.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ kernel_size=4,
+ scale_factor=2):
+ super(DeconvModule, self).__init__()
+
+ assert (kernel_size - scale_factor >= 0) and\
+ (kernel_size - scale_factor) % 2 == 0,\
+ f'kernel_size should be greater than or equal to scale_factor '\
+ f'and (kernel_size - scale_factor) should be even numbers, '\
+ f'while the kernel size is {kernel_size} and scale_factor is '\
+ f'{scale_factor}.'
+
+ stride = scale_factor
+ padding = (kernel_size - scale_factor) // 2
+ self.with_cp = with_cp
+ deconv = nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+
+ norm_name, norm = build_norm_layer(norm_cfg, out_channels)
+ activate = build_activation_layer(act_cfg)
+ self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.deconv_upsamping, x)
+ else:
+ out = self.deconv_upsamping(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+ """Interpolation upsample module in decoder for UNet.
+
+ This module uses interpolation to upsample feature map in the decoder
+ of UNet. It consists of one interpolation upsample layer and one
+ convolutional layer. It can be one interpolation upsample layer followed
+ by one convolutional layer (conv_first=False) or one convolutional layer
+ followed by one interpolation upsample layer (conv_first=True).
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+        conv_first (bool): Whether the convolutional layer comes before the
+            interpolation upsample layer. Default: False, i.e. the
+            interpolation upsample layer is followed by the convolutional layer.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+ upsampe_cfg (dict): Interpolation config of the upsample layer.
+ Default: dict(
+ scale_factor=2, mode='bilinear', align_corners=False).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ conv_cfg=None,
+ conv_first=False,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ upsampe_cfg=dict(
+ scale_factor=2, mode='bilinear', align_corners=False)):
+ super(InterpConv, self).__init__()
+
+ self.with_cp = with_cp
+ conv = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ upsample = nn.Upsample(**upsampe_cfg)
+ if conv_first:
+ self.interp_upsample = nn.Sequential(conv, upsample)
+ else:
+ self.interp_upsample = nn.Sequential(upsample, conv)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.interp_upsample, x)
+ else:
+ out = self.interp_upsample(x)
+ return out
+
+
+class UNet(nn.Module):
+ """UNet backbone.
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
+ https://arxiv.org/pdf/1505.04597.pdf
+
+ Args:
+        in_channels (int): Number of input image channels. Default: 3.
+ base_channels (int): Number of base channels of each stage.
+ The output channels of the first stage. Default: 64.
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+ len(strides) is equal to num_stages. Normally the stride of the
+ first stage in encoder is 1. If strides[i]=2, it uses stride
+ convolution to downsample in the correspondence encoder stage.
+ Default: (1, 1, 1, 1, 1).
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence encoder stage.
+ Default: (2, 2, 2, 2, 2).
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence decoder stage.
+ Default: (2, 2, 2, 2).
+        downsamples (Sequence[int]): Whether to use MaxPool to downsample the
+            feature map after the first stage of encoder
+            (stages: [1, num_stages)). If the correspondence encoder stage uses
+            stride convolution (strides[i]=2), it will never use MaxPool to
+            downsample, even if downsamples[i-1] is True.
+ Default: (True, True, True, True).
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+ Default: (1, 1, 1, 1, 1).
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+ Default: (1, 1, 1, 1).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+
+ Notice:
+        The input image size should be divisible by the whole downsample rate
+        of the encoder. More detail of the whole downsample rate can be found
+        in UNet._check_input_divisible.
+
+ """
+
+ def __init__(self,
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False,
+ dcn=None,
+ plugins=None):
+ super(UNet, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert len(strides) == num_stages, \
+ 'The length of strides should be equal to num_stages, '\
+ f'while the strides is {strides}, the length of '\
+ f'strides is {len(strides)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_num_convs) == num_stages, \
+ 'The length of enc_num_convs should be equal to num_stages, '\
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_num_convs) == (num_stages-1), \
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(downsamples) == (num_stages-1), \
+ 'The length of downsamples should be equal to (num_stages-1), '\
+ f'while the downsamples is {downsamples}, the length of '\
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_dilations) == num_stages, \
+ 'The length of enc_dilations should be equal to num_stages, '\
+ f'while the enc_dilations is {enc_dilations}, the length of '\
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_dilations) == (num_stages-1), \
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
+ f'while the dec_dilations is {dec_dilations}, the length of '\
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ self.num_stages = num_stages
+ self.strides = strides
+ self.downsamples = downsamples
+ self.norm_eval = norm_eval
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for i in range(num_stages):
+ enc_conv_block = []
+ if i != 0:
+ if strides[i] == 1 and downsamples[i - 1]:
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+ upsample = (strides[i] != 1 or downsamples[i - 1])
+ self.decoder.append(
+ UpConvBlock(
+ conv_block=BasicConvBlock,
+ in_channels=base_channels * 2**i,
+ skip_channels=base_channels * 2**(i - 1),
+ out_channels=base_channels * 2**(i - 1),
+ num_convs=dec_num_convs[i - 1],
+ stride=1,
+ dilation=dec_dilations[i - 1],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ upsample_cfg=upsample_cfg if upsample else None,
+ dcn=None,
+ plugins=None))
+
+ enc_conv_block.append(
+ BasicConvBlock(
+ in_channels=in_channels,
+ out_channels=base_channels * 2**i,
+ num_convs=enc_num_convs[i],
+ stride=strides[i],
+ dilation=enc_dilations[i],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None))
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
+ in_channels = base_channels * 2**i
+
+ def forward(self, x):
+ enc_outs = []
+
+ for enc in self.encoder:
+ x = enc(x)
+ enc_outs.append(x)
+ dec_outs = [x]
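+        # dec_outs[0] is the bottleneck feature; subsequent entries are decoder
+        # outputs at progressively higher resolutions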
+ for i in reversed(range(len(self.decoder))):
+ x = self.decoder[i](enc_outs[i], x)
+ dec_outs.append(x)
+
+ return dec_outs
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+
+class ShapeUNet(nn.Module):
+ """ShapeUNet backbone with small modifications.
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
+ https://arxiv.org/pdf/1505.04597.pdf
+
+ Args:
+        in_channels (int): Number of input image channels. Default: 3.
+ base_channels (int): Number of base channels of each stage.
+ The output channels of the first stage. Default: 64.
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+        attr_embedding (int): Dimension of the attribute embedding that is
+            concatenated to the input of every encoder stage. Default: 128.
+        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+            len(strides) is equal to num_stages. Normally the stride of the
+            first stage in encoder is 1. If strides[i]=2, it uses stride
+            convolution to downsample in the correspondence encoder stage.
+            Default: (1, 1, 1, 1, 1).
+        enc_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the correspondence encoder stage.
+            Default: (2, 2, 2, 2, 2).
+        dec_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the correspondence decoder stage.
+            Default: (2, 2, 2, 2).
+        downsamples (Sequence[int]): Whether to use MaxPool to downsample the
+            feature map after the first stage of encoder
+            (stages: [1, num_stages)). If the correspondence encoder stage uses
+            stride convolution (strides[i]=2), it will never use MaxPool to
+            downsample, even if downsamples[i-1] is True.
+            Default: (True, True, True, True).
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+ Default: (1, 1, 1, 1, 1).
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+ Default: (1, 1, 1, 1).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+
+ Notice:
+        The input image size should be divisible by the whole downsample rate
+        of the encoder. More detail of the whole downsample rate can be found
+        in UNet._check_input_divisible.
+
+ """
+
+ def __init__(self,
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ attr_embedding=128,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False,
+ dcn=None,
+ plugins=None):
+ super(ShapeUNet, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert len(strides) == num_stages, \
+ 'The length of strides should be equal to num_stages, '\
+ f'while the strides is {strides}, the length of '\
+ f'strides is {len(strides)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_num_convs) == num_stages, \
+ 'The length of enc_num_convs should be equal to num_stages, '\
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_num_convs) == (num_stages-1), \
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(downsamples) == (num_stages-1), \
+ 'The length of downsamples should be equal to (num_stages-1), '\
+ f'while the downsamples is {downsamples}, the length of '\
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_dilations) == num_stages, \
+ 'The length of enc_dilations should be equal to num_stages, '\
+ f'while the enc_dilations is {enc_dilations}, the length of '\
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_dilations) == (num_stages-1), \
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
+ f'while the dec_dilations is {dec_dilations}, the length of '\
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ self.num_stages = num_stages
+ self.strides = strides
+ self.downsamples = downsamples
+ self.norm_eval = norm_eval
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for i in range(num_stages):
+ enc_conv_block = []
+ if i != 0:
+ if strides[i] == 1 and downsamples[i - 1]:
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+ upsample = (strides[i] != 1 or downsamples[i - 1])
+ self.decoder.append(
+ UpConvBlock(
+ conv_block=BasicConvBlock,
+ in_channels=base_channels * 2**i,
+ skip_channels=base_channels * 2**(i - 1),
+ out_channels=base_channels * 2**(i - 1),
+ num_convs=dec_num_convs[i - 1],
+ stride=1,
+ dilation=dec_dilations[i - 1],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ upsample_cfg=upsample_cfg if upsample else None,
+ dcn=None,
+ plugins=None))
+
+ enc_conv_block.append(
+ BasicConvBlock(
+ in_channels=in_channels + attr_embedding,
+ out_channels=base_channels * 2**i,
+ num_convs=enc_num_convs[i],
+ stride=strides[i],
+ dilation=enc_dilations[i],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None))
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
+ in_channels = base_channels * 2**i
+
+ def forward(self, x, attr_embedding):
+ enc_outs = []
+ Be, Ce = attr_embedding.size()
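+        # broadcast the attribute embedding over the spatial dimensions and
+        # concatenate it to the input of every encoder stage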
+ for enc in self.encoder:
+ _, _, H, W = x.size()
+ x = enc(
+ torch.cat([
+ x,
+ attr_embedding.view(Be, Ce, 1, 1).expand((Be, Ce, H, W))
+ ],
+ dim=1))
+ enc_outs.append(x)
+ dec_outs = [x]
+ for i in reversed(range(len(self.decoder))):
+ x = self.decoder[i](enc_outs[i], x)
+ dec_outs.append(x)
+
+ return dec_outs
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
diff --git a/Text2Human/models/archs/vqgan_arch.py b/Text2Human/models/archs/vqgan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..51980ec048dc25e5c84ae26ba6bde384d1d2a94f
--- /dev/null
+++ b/Text2Human/models/archs/vqgan_arch.py
@@ -0,0 +1,1203 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+ """
+    Improved version of the original VectorQuantizer that can be used as a
+    drop-in replacement. Mostly avoids costly matrix multiplications and
+    allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self,
+ n_e,
+ e_dim,
+ beta,
+ remap=None,
+ unknown_index="random",
+ sane_index_shape=False,
+ legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(
+ 0, self.re_embed,
+ size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
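+        # (straight-through estimator: the forward pass uses z_q, while the
+        # gradient flows back to the encoder output z)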
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1,
+ 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class VectorQuantizerTexture(nn.Module):
+ """
+    Texture-aware variant of VectorQuantizer: it keeps 18 separate codebooks,
+    one per segmentation (texture) class, and quantizes each latent position
+    against the codebook of its segmentation label. Like VectorQuantizer, it
+    mostly avoids costly matrix multiplications and allows for post-hoc
+    remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self,
+ n_e,
+ e_dim,
+ beta,
+ remap=None,
+ unknown_index="random",
+ sane_index_shape=False,
+ legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ # TODO: decide number of embeddings
+ self.embedding_list = nn.ModuleList(
+ [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
+ for embedding in self.embedding_list:
+ embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(
+ 0, self.re_embed,
+ size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self,
+ z,
+ segm_map,
+ temp=None,
+ rescale_logits=False,
+ return_logits=False):
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+
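+        # downsample the segmentation map to the latent resolution; each latent
+        # position is then quantized with the codebook of its segmentation class
+        # (18 region-specific codebooks)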
+ segm_map = F.interpolate(segm_map, size=z.size()[2:], mode='nearest')
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+
+ # flatten segm_map (b, h, w)
+ segm_map_flatten = segm_map.view(-1)
+
+ z_q = torch.zeros_like(z_flattened)
+ min_encoding_indices_list = []
+ min_encoding_indices_continual = torch.full(
+ segm_map_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=segm_map_flatten.device)
+ for codebook_idx in range(18):
+ min_encoding_indices = torch.full(
+ segm_map_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=segm_map_flatten.device)
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
+ z_selected = z_flattened[segm_map_flatten == codebook_idx]
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d_selected = torch.sum(
+ z_selected**2, dim=1, keepdim=True) + torch.sum(
+ self.embedding_list[codebook_idx].weight**2,
+ dim=1) - 2 * torch.einsum(
+ 'bd,dn->bn', z_selected,
+ rearrange(self.embedding_list[codebook_idx].weight,
+ 'n d -> d n'))
+ min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
+ z_q_selected = self.embedding_list[codebook_idx](
+ min_encoding_indices_selected)
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
+ min_encoding_indices[
+ segm_map_flatten ==
+ codebook_idx] = min_encoding_indices_selected
+ min_encoding_indices_continual[
+ segm_map_flatten ==
+ codebook_idx] = min_encoding_indices_selected + 1024 * codebook_idx
+ min_encoding_indices = min_encoding_indices.reshape(
+ z.shape[0], z.shape[1], z.shape[2])
+ min_encoding_indices_list.append(min_encoding_indices)
+
+ min_encoding_indices_continual = min_encoding_indices_continual.reshape(
+ z.shape[0], z.shape[1], z.shape[2])
+ z_q = z_q.view(z.shape)
+ perplexity = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ return z_q, loss, (perplexity, min_encoding_indices_continual,
+ min_encoding_indices_list)
+
+ def get_codebook_entry(self, indices_list, segm_map, shape):
+ # flatten segm_map (b, h, w)
+ segm_map = F.interpolate(
+ segm_map, size=(shape[1], shape[2]), mode='nearest')
+ segm_map_flatten = segm_map.view(-1)
+
+ z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
+ self.e_dim).to(segm_map.device)
+ for codebook_idx in range(18):
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
+ min_encoding_indices_selected = indices_list[
+ codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
+ z_q_selected = self.embedding_list[codebook_idx](
+ min_encoding_indices_selected)
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
+
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+def sample_patches(inputs, patch_size=3, stride=1):
+ """Extract sliding local patches from an input feature tensor.
+    The sampled patches are row-major.
+ Args:
+ inputs (Tensor): the input feature maps, shape: (n, c, h, w).
+ patch_size (int): the spatial size of sampled patches. Default: 3.
+ stride (int): the stride of sampling. Default: 1.
+ Returns:
+ patches (Tensor): extracted patches, shape: (n, c * patch_size *
+ patch_size, n_patches).
+ """
+
+ patches = F.unfold(inputs, (patch_size, patch_size), stride=stride)
+
+ return patches
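+
+
+# A hedged shape example (values are illustrative only):
+#   x = torch.randn(1, 256, 8, 8)
+#   p = sample_patches(x, patch_size=2, stride=2)
+#   p.shape  # -> (1, 256 * 2 * 2, 16): sixteen non-overlapping 2x2 patches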
+
+
+class VectorQuantizerSpatialTextureAware(nn.Module):
+ """
+    Patch-based, texture-aware variant of VectorQuantizer: non-overlapping
+    spatial_size x spatial_size patches are flattened and quantized jointly,
+    each against one of 18 segmentation-specific codebooks. Like
+    VectorQuantizer, it mostly avoids costly matrix multiplications and
+    allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self,
+ n_e,
+ e_dim,
+ beta,
+ spatial_size,
+ remap=None,
+ unknown_index="random",
+ sane_index_shape=False,
+ legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim * spatial_size * spatial_size
+ self.beta = beta
+ self.legacy = legacy
+ self.spatial_size = spatial_size
+
+ # TODO: decide number of embeddings
+ self.embedding_list = nn.ModuleList(
+ [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
+ for embedding in self.embedding_list:
+ embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def forward(self,
+ z,
+ segm_map,
+ temp=None,
+ rescale_logits=False,
+ return_logits=False):
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+
+ segm_map = F.interpolate(
+ segm_map,
+ size=(z.size(2) // self.spatial_size,
+ z.size(3) // self.spatial_size),
+ mode='nearest')
+
+        # extract non-overlapping spatial patches and flatten each patch: every
+        # spatial_size x spatial_size patch of z becomes a vector of length
+        # e_dim * spatial_size**2
+ z_patches = sample_patches(
+ z, patch_size=self.spatial_size,
+ stride=self.spatial_size).permute(0, 2, 1)
+ z_patches_flattened = z_patches.reshape(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ # flatten segm_map (b, h, w)
+ segm_map_flatten = segm_map.view(-1)
+
+ z_q = torch.zeros_like(z_patches_flattened)
+ min_encoding_indices_list = []
+ min_encoding_indices_continual = torch.full(
+ segm_map_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=segm_map_flatten.device)
+
+ for codebook_idx in range(18):
+ min_encoding_indices = torch.full(
+ segm_map_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=segm_map_flatten.device)
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
+ z_selected = z_patches_flattened[segm_map_flatten ==
+ codebook_idx]
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d_selected = torch.sum(
+ z_selected**2, dim=1, keepdim=True) + torch.sum(
+ self.embedding_list[codebook_idx].weight**2,
+ dim=1) - 2 * torch.einsum(
+ 'bd,dn->bn', z_selected,
+ rearrange(self.embedding_list[codebook_idx].weight,
+ 'n d -> d n'))
+ min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
+ z_q_selected = self.embedding_list[codebook_idx](
+ min_encoding_indices_selected)
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
+ min_encoding_indices[
+ segm_map_flatten ==
+ codebook_idx] = min_encoding_indices_selected
+ min_encoding_indices_continual[
+ segm_map_flatten ==
+ codebook_idx] = min_encoding_indices_selected + self.n_e * codebook_idx
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_patches.shape[0], segm_map.shape[2], segm_map.shape[3])
+ min_encoding_indices_list.append(min_encoding_indices)
+
+ z_q = F.fold(
+ z_q.view(z_patches.shape).permute(0, 2, 1),
+ z.size()[2:],
+ kernel_size=(self.spatial_size, self.spatial_size),
+ stride=self.spatial_size)
+
+ perplexity = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ return z_q, loss, (perplexity, min_encoding_indices_continual,
+ min_encoding_indices_list)
+
+ def get_codebook_entry(self, indices_list, segm_map, shape):
+ # flatten segm_map (b, h, w)
+ segm_map = F.interpolate(
+ segm_map, size=(shape[1], shape[2]), mode='nearest')
+ segm_map_flatten = segm_map.view(-1)
+
+ z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
+ self.e_dim).to(segm_map.device)
+ for codebook_idx in range(18):
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
+ min_encoding_indices_selected = indices_list[
+ codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
+ z_q_selected = self.embedding_list[codebook_idx](
+ min_encoding_indices_selected)
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
+
+ z_q = F.fold(
+ z_q.view(((shape[0], shape[1] * shape[2],
+ self.e_dim))).permute(0, 2, 1),
+ (shape[1] * self.spatial_size, shape[2] * self.spatial_size),
+ kernel_size=(self.spatial_size, self.spatial_size),
+ stride=self.spatial_size)
+
+ return z_q
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
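+
+
+# Concretely, with half_dim = embedding_dim // 2 and frequencies
+# f_j = exp(-j * log(10000) / (half_dim - 1)), the row for timestep t_i is
+# [sin(t_i * f_0), ..., sin(t_i * f_{half_dim-1}),
+#  cos(t_i * f_0), ..., cos(t_i * f_{half_dim-1})].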
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(
+ x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+
+ def __init__(self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock(nn.Module):
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(
+ v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class Model(nn.Module):
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x, t=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
+ dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
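+# Editorial note: Model is the full encoder/decoder UNet with optional
+# sinusoidal timestep conditioning via get_timestep_embedding. A hypothetical
+# instantiation (shapes purely illustrative):
+#   m = Model(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
+#             attn_resolutions=[16], in_channels=3, resolution=64)
+#   y = m(torch.randn(1, 3, 64, 64), t=torch.zeros(1, dtype=torch.long))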
+
+class Encoder(nn.Module):
+
+ def __init__(self,
+ ch,
+ num_res_blocks,
+ attn_resolutions,
+ in_channels,
+ resolution,
+ z_channels,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ double_z=True):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ resolution,
+ z_channels,
+ ch,
+ out_ch,
+ num_res_blocks,
+ attn_resolutions,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ give_pre_end=False):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2**(self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res // 2)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, z, bot_h=None):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ if i_level == 4 and bot_h is not None:
+ h += bot_h
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_feature_top(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ if i_level == 4:
+ return h
+
+ def get_feature_middle(self, z, mid_h):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ if i_level == 4:
+ h += mid_h
+ if i_level == 3:
+ return h
+
+
+class DecoderRes(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ resolution,
+ z_channels,
+ ch,
+ num_res_blocks,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ give_pre_end=False):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2**(self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res // 2)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ return h
+
+
+# patch-based (PatchGAN-style) discriminator
+class Discriminator(nn.Module):
+
+ def __init__(self, nc, ndf, n_layers=3):
+ super().__init__()
+
+ layers = [
+ nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
+ nn.LeakyReLU(0.2, True)
+ ]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1,
+ n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2**n, 8)
+ layers += [
+ nn.Conv2d(
+ ndf * ndf_mult_prev,
+ ndf * ndf_mult,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2**n_layers, 8)
+
+ layers += [
+ nn.Conv2d(
+ ndf * ndf_mult_prev,
+ ndf * ndf_mult,
+ kernel_size=4,
+ stride=1,
+ padding=1,
+ bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.main(x)
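+
+
+if __name__ == '__main__':
+    # Illustrative sanity check (editorial addition, not part of the original
+    # release): the PatchGAN-style discriminator maps a 256x256 image to a
+    # 30x30 one-channel patch map when n_layers=3.
+    disc = Discriminator(nc=3, ndf=64, n_layers=3)
+    out = disc(torch.randn(1, 3, 256, 256))
+    print(out.shape)  # expected: torch.Size([1, 1, 30, 30])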
diff --git a/Text2Human/models/hierarchy_inference_model.py b/Text2Human/models/hierarchy_inference_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3116307caa051cec1a2d0e3793f459f92b44fd80
--- /dev/null
+++ b/Text2Human/models/hierarchy_inference_model.py
@@ -0,0 +1,363 @@
+import logging
+import math
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+from torchvision.utils import save_image
+
+from models.archs.fcn_arch import MultiHeadFCNHead
+from models.archs.unet_arch import UNet
+from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
+ VectorQuantizerSpatialTextureAware,
+ VectorQuantizerTexture)
+from models.losses.accuracy import accuracy
+from models.losses.cross_entropy_loss import CrossEntropyLoss
+
+logger = logging.getLogger('base')
+
+
+class VQGANTextureAwareSpatialHierarchyInferenceModel():
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda')
+ self.is_train = opt['is_train']
+
+ self.top_encoder = Encoder(
+ ch=opt['top_ch'],
+ num_res_blocks=opt['top_num_res_blocks'],
+ attn_resolutions=opt['top_attn_resolutions'],
+ ch_mult=opt['top_ch_mult'],
+ in_channels=opt['top_in_channels'],
+ resolution=opt['top_resolution'],
+ z_channels=opt['top_z_channels'],
+ double_z=opt['top_double_z'],
+ dropout=opt['top_dropout']).to(self.device)
+ self.decoder = Decoder(
+ in_channels=opt['top_in_channels'],
+ resolution=opt['top_resolution'],
+ z_channels=opt['top_z_channels'],
+ ch=opt['top_ch'],
+ out_ch=opt['top_out_ch'],
+ num_res_blocks=opt['top_num_res_blocks'],
+ attn_resolutions=opt['top_attn_resolutions'],
+ ch_mult=opt['top_ch_mult'],
+ dropout=opt['top_dropout'],
+ resamp_with_conv=True,
+ give_pre_end=False).to(self.device)
+ self.top_quantize = VectorQuantizerTexture(
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
+ self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
+ opt['embed_dim'],
+ 1).to(self.device)
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["top_z_channels"],
+ 1).to(self.device)
+ self.load_top_pretrain_models()
+
+ self.bot_encoder = Encoder(
+ ch=opt['bot_ch'],
+ num_res_blocks=opt['bot_num_res_blocks'],
+ attn_resolutions=opt['bot_attn_resolutions'],
+ ch_mult=opt['bot_ch_mult'],
+ in_channels=opt['bot_in_channels'],
+ resolution=opt['bot_resolution'],
+ z_channels=opt['bot_z_channels'],
+ double_z=opt['bot_double_z'],
+ dropout=opt['bot_dropout']).to(self.device)
+ self.bot_decoder_res = DecoderRes(
+ in_channels=opt['bot_in_channels'],
+ resolution=opt['bot_resolution'],
+ z_channels=opt['bot_z_channels'],
+ ch=opt['bot_ch'],
+ num_res_blocks=opt['bot_num_res_blocks'],
+ ch_mult=opt['bot_ch_mult'],
+ dropout=opt['bot_dropout'],
+ give_pre_end=False).to(self.device)
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
+ opt['bot_n_embed'],
+ opt['embed_dim'],
+ beta=0.25,
+ spatial_size=opt['codebook_spatial_size']).to(self.device)
+ self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
+ opt['embed_dim'],
+ 1).to(self.device)
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["bot_z_channels"],
+ 1).to(self.device)
+
+ self.load_bot_pretrain_network()
+
+ self.guidance_encoder = UNet(
+ in_channels=opt['encoder_in_channels']).to(self.device)
+ self.index_decoder = MultiHeadFCNHead(
+ in_channels=opt['fc_in_channels'],
+ in_index=opt['fc_in_index'],
+ channels=opt['fc_channels'],
+ num_convs=opt['fc_num_convs'],
+ concat_input=opt['fc_concat_input'],
+ dropout_ratio=opt['fc_dropout_ratio'],
+ num_classes=opt['fc_num_classes'],
+ align_corners=opt['fc_align_corners'],
+ num_head=18).to(self.device)
+
+ self.init_training_settings()
+
+ def init_training_settings(self):
+ optim_params = []
+ for v in self.guidance_encoder.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ for v in self.index_decoder.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ # set up optimizers
+ if self.opt['optimizer'] == 'Adam':
+ self.optimizer = torch.optim.Adam(
+ optim_params,
+ self.opt['lr'],
+ weight_decay=self.opt['weight_decay'])
+ elif self.opt['optimizer'] == 'SGD':
+ self.optimizer = torch.optim.SGD(
+ optim_params,
+ self.opt['lr'],
+ momentum=self.opt['momentum'],
+ weight_decay=self.opt['weight_decay'])
+ self.log_dict = OrderedDict()
+ if self.opt['loss_function'] == 'cross_entropy':
+ self.loss_func = CrossEntropyLoss().to(self.device)
+
+ def load_top_pretrain_models(self):
+        # load the pretrained top-level texture-aware image VQGAN
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+ self.top_encoder.load_state_dict(
+ top_vae_checkpoint['encoder'], strict=True)
+ self.decoder.load_state_dict(
+ top_vae_checkpoint['decoder'], strict=True)
+ self.top_quantize.load_state_dict(
+ top_vae_checkpoint['quantize'], strict=True)
+ self.top_quant_conv.load_state_dict(
+ top_vae_checkpoint['quant_conv'], strict=True)
+ self.top_post_quant_conv.load_state_dict(
+ top_vae_checkpoint['post_quant_conv'], strict=True)
+ self.top_encoder.eval()
+ self.top_quantize.eval()
+ self.top_quant_conv.eval()
+ self.top_post_quant_conv.eval()
+
+ def load_bot_pretrain_network(self):
+ checkpoint = torch.load(self.opt['bot_vae_path'])
+ self.bot_encoder.load_state_dict(
+ checkpoint['bot_encoder'], strict=True)
+ self.bot_decoder_res.load_state_dict(
+ checkpoint['bot_decoder_res'], strict=True)
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
+ self.bot_quantize.load_state_dict(
+ checkpoint['bot_quantize'], strict=True)
+ self.bot_quant_conv.load_state_dict(
+ checkpoint['bot_quant_conv'], strict=True)
+ self.bot_post_quant_conv.load_state_dict(
+ checkpoint['bot_post_quant_conv'], strict=True)
+
+ self.bot_encoder.eval()
+ self.bot_decoder_res.eval()
+ self.decoder.eval()
+ self.bot_quantize.eval()
+ self.bot_quant_conv.eval()
+ self.bot_post_quant_conv.eval()
+
+ def top_encode(self, x, mask):
+ h = self.top_encoder(x)
+ h = self.top_quant_conv(h)
+ quant, _, _ = self.top_quantize(h, mask)
+ quant = self.top_post_quant_conv(quant)
+
+ return quant, quant
+
+ def feed_data(self, data):
+ self.image = data['image'].to(self.device)
+ self.texture_mask = data['texture_mask'].float().to(self.device)
+ self.get_gt_indices()
+
+ self.texture_tokens = F.interpolate(
+ self.texture_mask, size=(32, 16),
+ mode='nearest').view(self.image.size(0), -1).long()
+
+ def bot_encode(self, x, mask):
+ h = self.bot_encoder(x)
+ h = self.bot_quant_conv(h)
+ _, _, (_, _, indices_list) = self.bot_quantize(h, mask)
+
+ return indices_list
+
+ def get_gt_indices(self):
+ self.quant_t, self.feature_t = self.top_encode(self.image,
+ self.texture_mask)
+ self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)
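+
+    # Editorial note: gt_indices_list holds one index map per texture codebook
+    # (18 in total); the multi-head index decoder is trained to predict each of
+    # these maps from the fixed top-level feature map self.feature_t.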
+
+ def index_to_image(self, index_bottom_list, texture_mask):
+ quant_b = self.bot_quantize.get_codebook_entry(
+ index_bottom_list, texture_mask,
+ (index_bottom_list[0].size(0), index_bottom_list[0].size(1),
+ index_bottom_list[0].size(2),
+ self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
+ quant_b = self.bot_post_quant_conv(quant_b)
+ bot_dec_res = self.bot_decoder_res(quant_b)
+
+ dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
+
+ return dec
+
+ def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
+ rec_img = self.index_to_image(rec_img_index, texture_mask)
+ pred_img = self.index_to_image(pred_img_index, texture_mask)
+
+ base_img = self.decoder(self.quant_t)
+ img_cat = torch.cat([
+ self.image,
+ rec_img,
+ base_img,
+ pred_img,
+ ], dim=3).detach()
+ img_cat = ((img_cat + 1) / 2)
+ img_cat = img_cat.clamp_(0, 1)
+ save_image(img_cat, save_path, nrow=1, padding=4)
+
+ def optimize_parameters(self):
+ self.guidance_encoder.train()
+ self.index_decoder.train()
+
+ self.feature_enc = self.guidance_encoder(self.feature_t)
+ self.memory_logits_list = self.index_decoder(self.feature_enc)
+
+ loss = 0
+ for i in range(18):
+ loss += self.loss_func(
+ self.memory_logits_list[i],
+ self.gt_indices_list[i],
+ ignore_index=-1)
+
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ self.log_dict['loss_total'] = loss
+
+ def inference(self, data_loader, save_dir):
+ self.guidance_encoder.eval()
+ self.index_decoder.eval()
+
+ acc = 0
+ num = 0
+
+ for _, data in enumerate(data_loader):
+ self.feed_data(data)
+ img_name = data['img_name']
+
+ num += self.image.size(0)
+
+ texture_mask_flatten = self.texture_tokens.view(-1)
+ min_encodings_indices_list = [
+ torch.full(
+ texture_mask_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=texture_mask_flatten.device) for _ in range(18)
+ ]
+ with torch.no_grad():
+ self.feature_enc = self.guidance_encoder(self.feature_t)
+ memory_logits_list = self.index_decoder(self.feature_enc)
+ # memory_indices_pred = memory_logits.argmax(dim=1)
+ batch_acc = 0
+ for codebook_idx, memory_logits in enumerate(memory_logits_list):
+ region_of_interest = texture_mask_flatten == codebook_idx
+ if torch.sum(region_of_interest) > 0:
+ memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
+ batch_acc += torch.sum(
+ memory_indices_pred[region_of_interest] ==
+ self.gt_indices_list[codebook_idx].view(
+ -1)[region_of_interest])
+ min_encodings_indices_list[codebook_idx][
+ region_of_interest] = memory_indices_pred[
+ region_of_interest]
+ min_encodings_indices_return_list = [
+ min_encodings_indices.view(self.gt_indices_list[0].size())
+ for min_encodings_indices in min_encodings_indices_list
+ ]
+ batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel(
+ ) * self.image.size(0)
+ acc += batch_acc
+ self.get_vis(min_encodings_indices_return_list,
+ self.gt_indices_list, self.texture_mask,
+ f'{save_dir}/{img_name[0]}')
+
+ self.guidance_encoder.train()
+ self.index_decoder.train()
+ return (acc / num).item()
+
+ def load_network(self):
+ checkpoint = torch.load(self.opt['pretrained_models'])
+ self.guidance_encoder.load_state_dict(
+ checkpoint['guidance_encoder'], strict=True)
+ self.guidance_encoder.eval()
+
+ self.index_decoder.load_state_dict(
+ checkpoint['index_decoder'], strict=True)
+ self.index_decoder.eval()
+
+ def save_network(self, save_path):
+ """Save networks.
+
+ Args:
+ net (nn.Module): Network to be saved.
+ net_label (str): Network label.
+ current_iter (int): Current iter number.
+ """
+
+ save_dict = {}
+ save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
+ save_dict['index_decoder'] = self.index_decoder.state_dict()
+
+ torch.save(save_dict, save_path)
+
+ def update_learning_rate(self, epoch):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
+ Default: -1.
+ """
+ lr = self.optimizer.param_groups[0]['lr']
+
+ if self.opt['lr_decay'] == 'step':
+ lr = self.opt['lr'] * (
+ self.opt['gamma']**(epoch // self.opt['step']))
+ elif self.opt['lr_decay'] == 'cos':
+ lr = self.opt['lr'] * (
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
+ elif self.opt['lr_decay'] == 'linear':
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
+ elif self.opt['lr_decay'] == 'linear2exp':
+ if epoch < self.opt['turning_point'] + 1:
+                # linear decay such that the learning rate has dropped by 95%
+                # (i.e. to 5% of its initial value) at the turning point
+                # (1 / 0.95 ~= 1.0526)
+ lr = self.opt['lr'] * (
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
+ else:
+ lr *= self.opt['gamma']
+ elif self.opt['lr_decay'] == 'schedule':
+ if epoch in self.opt['schedule']:
+ lr *= self.opt['gamma']
+ else:
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
+ # set learning rate
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = lr
+
+ return lr
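+
+    # Worked example (editorial): with lr=1e-4, lr_decay='step', gamma=0.1 and
+    # step=50, epoch 120 gives lr = 1e-4 * 0.1 ** (120 // 50) = 1e-6.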
+
+ def get_current_log(self):
+ return self.log_dict
diff --git a/Text2Human/models/hierarchy_vqgan_model.py b/Text2Human/models/hierarchy_vqgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b0d657864b5771bdbcd3ba134f4352ea2ca1e19
--- /dev/null
+++ b/Text2Human/models/hierarchy_vqgan_model.py
@@ -0,0 +1,374 @@
+import math
+import sys
+from collections import OrderedDict
+
+sys.path.append('..')
+import lpips
+import torch
+import torch.nn.functional as F
+from torchvision.utils import save_image
+
+from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
+ Encoder,
+ VectorQuantizerSpatialTextureAware,
+ VectorQuantizerTexture)
+from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
+ calculate_adaptive_weight, hinge_d_loss)
+
+
+class HierarchyVQSpatialTextureAwareModel():
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda')
+ self.top_encoder = Encoder(
+ ch=opt['top_ch'],
+ num_res_blocks=opt['top_num_res_blocks'],
+ attn_resolutions=opt['top_attn_resolutions'],
+ ch_mult=opt['top_ch_mult'],
+ in_channels=opt['top_in_channels'],
+ resolution=opt['top_resolution'],
+ z_channels=opt['top_z_channels'],
+ double_z=opt['top_double_z'],
+ dropout=opt['top_dropout']).to(self.device)
+ self.decoder = Decoder(
+ in_channels=opt['top_in_channels'],
+ resolution=opt['top_resolution'],
+ z_channels=opt['top_z_channels'],
+ ch=opt['top_ch'],
+ out_ch=opt['top_out_ch'],
+ num_res_blocks=opt['top_num_res_blocks'],
+ attn_resolutions=opt['top_attn_resolutions'],
+ ch_mult=opt['top_ch_mult'],
+ dropout=opt['top_dropout'],
+ resamp_with_conv=True,
+ give_pre_end=False).to(self.device)
+ self.top_quantize = VectorQuantizerTexture(
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
+ self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
+ opt['embed_dim'],
+ 1).to(self.device)
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["top_z_channels"],
+ 1).to(self.device)
+ self.load_top_pretrain_models()
+
+ self.bot_encoder = Encoder(
+ ch=opt['bot_ch'],
+ num_res_blocks=opt['bot_num_res_blocks'],
+ attn_resolutions=opt['bot_attn_resolutions'],
+ ch_mult=opt['bot_ch_mult'],
+ in_channels=opt['bot_in_channels'],
+ resolution=opt['bot_resolution'],
+ z_channels=opt['bot_z_channels'],
+ double_z=opt['bot_double_z'],
+ dropout=opt['bot_dropout']).to(self.device)
+ self.bot_decoder_res = DecoderRes(
+ in_channels=opt['bot_in_channels'],
+ resolution=opt['bot_resolution'],
+ z_channels=opt['bot_z_channels'],
+ ch=opt['bot_ch'],
+ num_res_blocks=opt['bot_num_res_blocks'],
+ ch_mult=opt['bot_ch_mult'],
+ dropout=opt['bot_dropout'],
+ give_pre_end=False).to(self.device)
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
+ opt['bot_n_embed'],
+ opt['embed_dim'],
+ beta=0.25,
+ spatial_size=opt['codebook_spatial_size']).to(self.device)
+ self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
+ opt['embed_dim'],
+ 1).to(self.device)
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["bot_z_channels"],
+ 1).to(self.device)
+
+ self.disc = Discriminator(
+ opt['n_channels'], opt['ndf'],
+ n_layers=opt['disc_layers']).to(self.device)
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
+ self.perceptual_weight = opt['perceptual_weight']
+ self.disc_start_step = opt['disc_start_step']
+ self.disc_weight_max = opt['disc_weight_max']
+ self.diff_aug = opt['diff_aug']
+ self.policy = "color,translation"
+
+ self.load_discriminator_models()
+
+ self.disc.train()
+
+ self.fix_decoder = opt['fix_decoder']
+
+ self.init_training_settings()
+
+ def load_top_pretrain_models(self):
+        # load the pretrained top-level texture-aware image VQGAN
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+ self.top_encoder.load_state_dict(
+ top_vae_checkpoint['encoder'], strict=True)
+ self.decoder.load_state_dict(
+ top_vae_checkpoint['decoder'], strict=True)
+ self.top_quantize.load_state_dict(
+ top_vae_checkpoint['quantize'], strict=True)
+ self.top_quant_conv.load_state_dict(
+ top_vae_checkpoint['quant_conv'], strict=True)
+ self.top_post_quant_conv.load_state_dict(
+ top_vae_checkpoint['post_quant_conv'], strict=True)
+ self.top_encoder.eval()
+ self.top_quantize.eval()
+ self.top_quant_conv.eval()
+ self.top_post_quant_conv.eval()
+
+ def init_training_settings(self):
+ self.log_dict = OrderedDict()
+ self.configure_optimizers()
+
+ def configure_optimizers(self):
+ optim_params = []
+ for v in self.bot_encoder.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ for v in self.bot_decoder_res.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ for v in self.bot_quantize.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ for v in self.bot_quant_conv.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ for v in self.bot_post_quant_conv.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ if not self.fix_decoder:
+ for name, v in self.decoder.named_parameters():
+ if v.requires_grad:
+ if 'up.0' in name:
+ optim_params.append(v)
+ if 'up.1' in name:
+ optim_params.append(v)
+ if 'up.2' in name:
+ optim_params.append(v)
+ if 'up.3' in name:
+ optim_params.append(v)
+
+ self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])
+
+ self.disc_optimizer = torch.optim.Adam(
+ self.disc.parameters(), lr=self.opt['lr'])
+
+ def load_discriminator_models(self):
+        # load the discriminator weights from the pretrained top-level VQGAN checkpoint
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+ self.disc.load_state_dict(
+ top_vae_checkpoint['discriminator'], strict=True)
+
+ def save_network(self, save_path):
+ """Save networks.
+ """
+
+ save_dict = {}
+ save_dict['bot_encoder'] = self.bot_encoder.state_dict()
+ save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
+ save_dict['decoder'] = self.decoder.state_dict()
+ save_dict['bot_quantize'] = self.bot_quantize.state_dict()
+ save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
+ save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict(
+ )
+ save_dict['discriminator'] = self.disc.state_dict()
+ torch.save(save_dict, save_path)
+
+ def load_network(self):
+ checkpoint = torch.load(self.opt['pretrained_models'])
+ self.bot_encoder.load_state_dict(
+ checkpoint['bot_encoder'], strict=True)
+ self.bot_decoder_res.load_state_dict(
+ checkpoint['bot_decoder_res'], strict=True)
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
+ self.bot_quantize.load_state_dict(
+ checkpoint['bot_quantize'], strict=True)
+ self.bot_quant_conv.load_state_dict(
+ checkpoint['bot_quant_conv'], strict=True)
+ self.bot_post_quant_conv.load_state_dict(
+ checkpoint['bot_post_quant_conv'], strict=True)
+
+ def optimize_parameters(self, data, step):
+ self.bot_encoder.train()
+ self.bot_decoder_res.train()
+ if not self.fix_decoder:
+ self.decoder.train()
+ self.bot_quantize.train()
+ self.bot_quant_conv.train()
+ self.bot_post_quant_conv.train()
+
+ loss, d_loss = self.training_step(data, step)
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ if step > self.disc_start_step:
+ self.disc_optimizer.zero_grad()
+ d_loss.backward()
+ self.disc_optimizer.step()
+
+ def top_encode(self, x, mask):
+ h = self.top_encoder(x)
+ h = self.top_quant_conv(h)
+ quant, _, _ = self.top_quantize(h, mask)
+ quant = self.top_post_quant_conv(quant)
+ return quant
+
+ def bot_encode(self, x, mask):
+ h = self.bot_encoder(x)
+ h = self.bot_quant_conv(h)
+ quant, emb_loss, info = self.bot_quantize(h, mask)
+ quant = self.bot_post_quant_conv(quant)
+ bot_dec_res = self.bot_decoder_res(quant)
+ return bot_dec_res, emb_loss, info
+
+ def decode(self, quant_top, bot_dec_res):
+ dec = self.decoder(quant_top, bot_h=bot_dec_res)
+ return dec
+
+ def forward_step(self, input, mask):
+ with torch.no_grad():
+ quant_top = self.top_encode(input, mask)
+ bot_dec_res, diff, _ = self.bot_encode(input, mask)
+ dec = self.decode(quant_top, bot_dec_res)
+ return dec, diff
+
+ def feed_data(self, data):
+ x = data['image'].float().to(self.device)
+ mask = data['texture_mask'].float().to(self.device)
+
+ return x, mask
+
+ def training_step(self, data, step):
+ x, mask = self.feed_data(data)
+ xrec, codebook_loss = self.forward_step(x, mask)
+
+ # get recon/perceptual loss
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
+ nll_loss = torch.mean(nll_loss)
+
+ # augment for input to discriminator
+ if self.diff_aug:
+ xrec = DiffAugment(xrec, policy=self.policy)
+
+ # update generator
+ logits_fake = self.disc(xrec)
+ g_loss = -torch.mean(logits_fake)
+ last_layer = self.decoder.conv_out.weight
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
+ self.disc_weight_max)
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
+ loss = nll_loss + d_weight * g_loss + codebook_loss
+
+ self.log_dict["loss"] = loss
+ self.log_dict["l1"] = recon_loss.mean().item()
+ self.log_dict["perceptual"] = p_loss.mean().item()
+ self.log_dict["nll_loss"] = nll_loss.item()
+ self.log_dict["g_loss"] = g_loss.item()
+ self.log_dict["d_weight"] = d_weight
+ self.log_dict["codebook_loss"] = codebook_loss.item()
+
+ if step > self.disc_start_step:
+ if self.diff_aug:
+ logits_real = self.disc(
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
+ else:
+ logits_real = self.disc(x.contiguous().detach())
+                logits_fake = self.disc(xrec.contiguous().detach(
+                ))  # detach so that the generator isn't also updated
+ d_loss = hinge_d_loss(logits_real, logits_fake)
+ self.log_dict["d_loss"] = d_loss
+ else:
+ d_loss = None
+
+ return loss, d_loss
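+
+    # Editorial note: the generator objective above follows the VQGAN
+    # formulation,
+    #   loss = L1 + perceptual_weight * LPIPS + d_weight * (-E[D(x_rec)]) + codebook_loss,
+    # where d_weight is the adaptive weight computed from the gradients of the
+    # decoder's last layer and is zeroed before disc_start_step.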
+
+ @torch.no_grad()
+ def inference(self, data_loader, save_dir):
+ self.bot_encoder.eval()
+ self.bot_decoder_res.eval()
+ self.decoder.eval()
+ self.bot_quantize.eval()
+ self.bot_quant_conv.eval()
+ self.bot_post_quant_conv.eval()
+
+ loss_total = 0
+ num = 0
+
+ for _, data in enumerate(data_loader):
+ img_name = data['img_name'][0]
+ x, mask = self.feed_data(data)
+ xrec, _ = self.forward_step(x, mask)
+
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
+ nll_loss = torch.mean(nll_loss)
+ loss_total += nll_loss
+
+ num += x.size(0)
+
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+
+ img_cat = torch.cat([x, xrec], dim=3).detach()
+ img_cat = ((img_cat + 1) / 2)
+ img_cat = img_cat.clamp_(0, 1)
+ save_image(
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
+
+ return (loss_total / num).item()
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def update_learning_rate(self, epoch):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
+ Default: -1.
+ """
+ lr = self.optimizer.param_groups[0]['lr']
+
+ if self.opt['lr_decay'] == 'step':
+ lr = self.opt['lr'] * (
+ self.opt['gamma']**(epoch // self.opt['step']))
+ elif self.opt['lr_decay'] == 'cos':
+ lr = self.opt['lr'] * (
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
+ elif self.opt['lr_decay'] == 'linear':
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
+ elif self.opt['lr_decay'] == 'linear2exp':
+ if epoch < self.opt['turning_point'] + 1:
+                # linear decay such that the learning rate has dropped by 95%
+                # (i.e. to 5% of its initial value) at the turning point
+                # (1 / 0.95 ~= 1.0526)
+ lr = self.opt['lr'] * (
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
+ else:
+ lr *= self.opt['gamma']
+ elif self.opt['lr_decay'] == 'schedule':
+ if epoch in self.opt['schedule']:
+ lr *= self.opt['gamma']
+ else:
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
+ # set learning rate
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = lr
+
+ return lr
diff --git a/Text2Human/models/losses/__init__.py b/Text2Human/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Text2Human/models/losses/accuracy.py b/Text2Human/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e17db52c85aa693fe8a2f6d0036afc432580cfc
--- /dev/null
+++ b/Text2Human/models/losses/accuracy.py
@@ -0,0 +1,46 @@
+def accuracy(pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+ pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
+        target (torch.Tensor): The target of each prediction, shape (N, ...)
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+ this threshold are considered incorrect. Default to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.size(0) == 0:
+ accu = [pred.new_tensor(0.) for i in range(len(topk))]
+ return accu[0] if return_single else accu
+ assert pred.ndim == target.ndim + 1
+ assert pred.size(0) == target.size(0)
+ assert maxk <= pred.size(1), \
+ f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+ pred_value, pred_label = pred.topk(maxk, dim=1)
+ # transpose to shape (maxk, N, ...)
+ pred_label = pred_label.transpose(0, 1)
+ correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
+ if thresh is not None:
+ # Only prediction values larger than thresh are counted as correct
+ correct = correct & (pred_value > thresh).t()
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / target.numel()))
+ return res[0] if return_single else res
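+
+
+if __name__ == '__main__':
+    # Illustrative example (editorial addition): top-1 accuracy on random
+    # logits; returned values are percentages in [0, 100].
+    import torch
+
+    logits = torch.randn(4, 10)           # (N, num_class)
+    labels = torch.randint(0, 10, (4, ))  # (N, )
+    print(accuracy(logits, labels, topk=1))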
diff --git a/Text2Human/models/losses/cross_entropy_loss.py b/Text2Human/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..87cc79d7ff8deba8ca9aa82eacae97b94e218fb0
--- /dev/null
+++ b/Text2Human/models/losses/cross_entropy_loss.py
@@ -0,0 +1,246 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights.
+ reduction (str): Same as built-in losses of PyTorch.
+        avg_factor (float): Average factor used when computing the mean of losses.
+
+ Returns:
+ Tensor: Processed loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ if weight.dim() > 1:
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if avg_factor is not specified, just reduce the loss
+ if avg_factor is None:
+ loss = reduce_loss(loss, reduction)
+ else:
+ # if reduction is mean, then average the loss by avg_factor
+ if reduction == 'mean':
+ loss = loss.sum() / avg_factor
+ # if reduction is 'none', then do nothing, otherwise raise an error
+ elif reduction != 'none':
+            raise ValueError('avg_factor cannot be used with reduction="sum"')
+ return loss
+
+
+def cross_entropy(pred,
+ label,
+ weight=None,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=-100):
+ """The wrapper function for :func:`F.cross_entropy`"""
+ # class_weight is a manual rescaling weight given to each class.
+ # If given, has to be a Tensor of size C element-wise losses
+ loss = F.cross_entropy(
+ pred,
+ label,
+ weight=class_weight,
+ reduction='none',
+ ignore_index=ignore_index)
+
+ # apply weights and do the reduction
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_zeros(target_shape)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(valid_mask, as_tuple=True)
+
+ if inds[0].numel() > 0:
+ if labels.dim() == 3:
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+ else:
+ bin_labels[inds[0], labels[valid_mask]] = 1
+
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+ bin_label_weights *= valid_mask
+
+ return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=255):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored. Default: 255
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ if pred.dim() != label.dim():
+ assert (pred.dim() == 2 and label.dim() == 1) or (
+ pred.dim() == 4 and label.dim() == 3), \
+ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+ 'H, W], label shape [N, H, W] are supported'
+ label, weight = _expand_onehot_labels(label, weight, pred.shape,
+ ignore_index)
+
+ # weighted element-wise losses
+ if weight is not None:
+ weight = weight.float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, label.float(), pos_weight=class_weight, reduction='none')
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(pred,
+ target,
+ label,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=None):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the mask's
+            corresponding object. It is used to select the mask of the class
+            that the object belongs to when the mask prediction is not
+            class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (None): Placeholder, to be consistent with other loss.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
+ # TODO: handle these two reserved arguments
+ assert reduction == 'mean' and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+class CrossEntropyLoss(nn.Module):
+ """CrossEntropyLoss.
+
+ Args:
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            instead of softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". Defaults to 'mean'.
+ class_weight (list[float], optional): Weight of each class.
+ Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0):
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = class_weight
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function."""
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_cls
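+
+
+if __name__ == '__main__':
+    # Illustrative usage (editorial addition, hypothetical shapes): standard
+    # per-pixel cross entropy over 19 classes.
+    criterion = CrossEntropyLoss()
+    logits = torch.randn(2, 19, 64, 32)         # (N, C, H, W)
+    target = torch.randint(0, 19, (2, 64, 32))  # (N, H, W)
+    print(criterion(logits, target))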
diff --git a/Text2Human/models/losses/segmentation_loss.py b/Text2Human/models/losses/segmentation_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..85cb46e4eea5510a95da23996fdd357bd8f8e743
--- /dev/null
+++ b/Text2Human/models/losses/segmentation_loss.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BCELoss(nn.Module):
+
+ def forward(self, prediction, target):
+ loss = F.binary_cross_entropy_with_logits(prediction, target)
+ return loss, {}
+
+
+class BCELossWithQuant(nn.Module):
+
+ def __init__(self, codebook_weight=1.):
+ super().__init__()
+ self.codebook_weight = codebook_weight
+
+ def forward(self, qloss, target, prediction, split):
+ bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
+ loss = bce_loss + self.codebook_weight * qloss
+ return loss, {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
+ "{}/quant_loss".format(split): qloss.detach().mean()
+ }
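+
+
+if __name__ == '__main__':
+    # Illustrative usage (editorial addition, hypothetical shapes).
+    import torch
+
+    loss_fn = BCELossWithQuant(codebook_weight=1.0)
+    pred = torch.randn(1, 24, 32, 16)                      # logits
+    target = torch.randint(0, 2, (1, 24, 32, 16)).float()  # binary targets
+    qloss = torch.tensor(0.1)                              # codebook loss term
+    total, log = loss_fn(qloss, target, pred, split='train')
+    print(total, log)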
diff --git a/Text2Human/models/losses/vqgan_loss.py b/Text2Human/models/losses/vqgan_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f07315711e2fabae548a2dc48743f593be6f8e39
--- /dev/null
+++ b/Text2Human/models/losses/vqgan_loss.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn.functional as F
+
+
+def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
+ recon_grads = torch.autograd.grad(
+ recon_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
+ return d_weight
+
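+# Editorial note: this is the adaptive discriminator weight from the VQGAN
+# formulation, d_weight = ||grad L_rec|| / (||grad L_GAN|| + 1e-4) w.r.t. the
+# decoder's last layer, clamped to [0, disc_weight_max] and detached.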
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+@torch.jit.script
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1. - logits_real))
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def DiffAugment(x, policy='', channels_first=True):
+ if policy:
+ if not channels_first:
+ x = x.permute(0, 3, 1, 2)
+ for p in policy.split(','):
+ for f in AUGMENT_FNS[p]:
+ x = f(x)
+ if not channels_first:
+ x = x.permute(0, 2, 3, 1)
+ x = x.contiguous()
+ return x
+
+
+def rand_brightness(x):
+ x = x + (
+ torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+ return x
+
+
+def rand_saturation(x):
+ x_mean = x.mean(dim=1, keepdim=True)
+ x = (x - x_mean) * (torch.rand(
+ x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
+ return x
+
+
+def rand_contrast(x):
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+ x = (x - x_mean) * (torch.rand(
+ x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
+ return x
+
+
+def rand_translation(x, ratio=0.125):
+ shift_x, shift_y = int(x.size(2) * ratio +
+ 0.5), int(x.size(3) * ratio + 0.5)
+ translation_x = torch.randint(
+ -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
+ translation_y = torch.randint(
+ -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
+ )
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
+ grid_y].permute(0, 3, 1, 2)
+ return x
+
+
+def rand_cutout(x, ratio=0.5):
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+ offset_x = torch.randint(
+ 0,
+ x.size(2) + (1 - cutout_size[0] % 2),
+ size=[x.size(0), 1, 1],
+ device=x.device)
+ offset_y = torch.randint(
+ 0,
+ x.size(3) + (1 - cutout_size[1] % 2),
+ size=[x.size(0), 1, 1],
+ device=x.device)
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+ )
+ grid_x = torch.clamp(
+ grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
+ grid_y = torch.clamp(
+ grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
+ mask = torch.ones(
+ x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+ mask[grid_batch, grid_x, grid_y] = 0
+ x = x * mask.unsqueeze(1)
+ return x
+
+
+AUGMENT_FNS = {
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
+ 'translation': [rand_translation],
+ 'cutout': [rand_cutout],
+}
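+
+# Illustrative usage (editorial addition): differentiable augmentation applied
+# to discriminator inputs; the models in this repo use the policy string
+# 'color,translation', e.g.
+#   x_aug = DiffAugment(torch.randn(2, 3, 256, 256), policy='color,translation')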
diff --git a/Text2Human/models/parsing_gen_model.py b/Text2Human/models/parsing_gen_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9440345dcf08d0ef9441d48734a3acaf2e2a4b5f
--- /dev/null
+++ b/Text2Human/models/parsing_gen_model.py
@@ -0,0 +1,220 @@
+import logging
+import math
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import torch
+from torchvision.utils import save_image
+
+from models.archs.fcn_arch import FCNHead
+from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
+from models.archs.unet_arch import ShapeUNet
+from models.losses.accuracy import accuracy
+from models.losses.cross_entropy_loss import CrossEntropyLoss
+
+logger = logging.getLogger('base')
+
+
+class ParsingGenModel():
+ """Paring Generation model.
+ """
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda')
+ self.is_train = opt['is_train']
+
+ self.attr_embedder = ShapeAttrEmbedding(
+ dim=opt['embedder_dim'],
+ out_dim=opt['embedder_out_dim'],
+ cls_num_list=opt['attr_class_num']).to(self.device)
+ self.parsing_encoder = ShapeUNet(
+ in_channels=opt['encoder_in_channels']).to(self.device)
+ self.parsing_decoder = FCNHead(
+ in_channels=opt['fc_in_channels'],
+ in_index=opt['fc_in_index'],
+ channels=opt['fc_channels'],
+ num_convs=opt['fc_num_convs'],
+ concat_input=opt['fc_concat_input'],
+ dropout_ratio=opt['fc_dropout_ratio'],
+ num_classes=opt['fc_num_classes'],
+ align_corners=opt['fc_align_corners'],
+ ).to(self.device)
+
+ self.init_training_settings()
+
+ self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
+ [250, 235, 215], [255, 250, 205], [211, 211, 211],
+ [70, 130, 180], [127, 255, 212], [0, 100, 0],
+ [50, 205, 50], [255, 255, 0], [245, 222, 179],
+ [255, 140, 0], [255, 0, 0], [16, 78, 139],
+ [144, 238, 144], [50, 205, 174], [50, 155, 250],
+ [160, 140, 88], [213, 140, 88], [90, 140, 90],
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
+
+ def init_training_settings(self):
+ optim_params = []
+ for v in self.attr_embedder.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ for v in self.parsing_encoder.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ for v in self.parsing_decoder.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ # set up optimizers
+ self.optimizer = torch.optim.Adam(
+ optim_params,
+ self.opt['lr'],
+ weight_decay=self.opt['weight_decay'])
+ self.log_dict = OrderedDict()
+ self.entropy_loss = CrossEntropyLoss().to(self.device)
+
+ def feed_data(self, data):
+ self.pose = data['densepose'].to(self.device)
+ self.attr = data['attr'].to(self.device)
+ self.segm = data['segm'].to(self.device)
+
+ def optimize_parameters(self):
+ self.attr_embedder.train()
+ self.parsing_encoder.train()
+ self.parsing_decoder.train()
+
+ self.attr_embedding = self.attr_embedder(self.attr)
+ self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
+ self.seg_logits = self.parsing_decoder(self.pose_enc)
+
+ loss = self.entropy_loss(self.seg_logits, self.segm)
+
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ self.log_dict['loss_total'] = loss
+
+ def get_vis(self, save_path):
+ img_cat = torch.cat([
+ self.pose,
+ self.segm,
+ ], dim=3).detach()
+ img_cat = ((img_cat + 1) / 2)
+
+ img_cat = img_cat.clamp_(0, 1)
+
+ save_image(img_cat, save_path, nrow=1, padding=4)
+
+ def inference(self, data_loader, save_dir):
+ self.attr_embedder.eval()
+ self.parsing_encoder.eval()
+ self.parsing_decoder.eval()
+
+ acc = 0
+ num = 0
+
+ for _, data in enumerate(data_loader):
+ pose = data['densepose'].to(self.device)
+ attr = data['attr'].to(self.device)
+ segm = data['segm'].to(self.device)
+ img_name = data['img_name']
+
+ num += pose.size(0)
+ with torch.no_grad():
+ attr_embedding = self.attr_embedder(attr)
+ pose_enc = self.parsing_encoder(pose, attr_embedding)
+ seg_logits = self.parsing_decoder(pose_enc)
+ seg_pred = seg_logits.argmax(dim=1)
+ acc += accuracy(seg_logits, segm)
+ palette_label = self.palette_result(segm.cpu().numpy())
+ palette_pred = self.palette_result(seg_pred.cpu().numpy())
+ pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
+ 3,
+ pose[0].size(1),
+ pose[0].size(2),
+ ).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
+ concat_result = np.concatenate(
+ (pose_numpy, palette_pred, palette_label), axis=1)
+ mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')
+
+ self.attr_embedder.train()
+ self.parsing_encoder.train()
+ self.parsing_decoder.train()
+ return (acc / num).item()
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def update_learning_rate(self, epoch):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
+ Default: -1.
+ """
+ lr = self.optimizer.param_groups[0]['lr']
+
+ if self.opt['lr_decay'] == 'step':
+ lr = self.opt['lr'] * (
+ self.opt['gamma']**(epoch // self.opt['step']))
+ elif self.opt['lr_decay'] == 'cos':
+ lr = self.opt['lr'] * (
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
+ elif self.opt['lr_decay'] == 'linear':
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
+ elif self.opt['lr_decay'] == 'linear2exp':
+ if epoch < self.opt['turning_point'] + 1:
+                # linear decay such that the learning rate has dropped by 95%
+                # (i.e. to 5% of its initial value) at the turning point
+                # (1 / 0.95 ~= 1.0526)
+ lr = self.opt['lr'] * (
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
+ else:
+ lr *= self.opt['gamma']
+ elif self.opt['lr_decay'] == 'schedule':
+ if epoch in self.opt['schedule']:
+ lr *= self.opt['gamma']
+ else:
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
+ # set learning rate
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = lr
+
+ return lr
+
+ def save_network(self, save_path):
+ """Save networks.
+ """
+
+ save_dict = {}
+ save_dict['embedder'] = self.attr_embedder.state_dict()
+ save_dict['encoder'] = self.parsing_encoder.state_dict()
+ save_dict['decoder'] = self.parsing_decoder.state_dict()
+
+ torch.save(save_dict, save_path)
+
+ def load_network(self):
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+
+ self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
+ self.attr_embedder.eval()
+
+ self.parsing_encoder.load_state_dict(
+ checkpoint['encoder'], strict=True)
+ self.parsing_encoder.eval()
+
+ self.parsing_decoder.load_state_dict(
+ checkpoint['decoder'], strict=True)
+ self.parsing_decoder.eval()
+
+ def palette_result(self, result):
+ seg = result[0]
+ palette = np.array(self.palette)
+ assert palette.shape[1] == 3
+ assert len(palette.shape) == 2
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ # convert to BGR
+ color_seg = color_seg[..., ::-1]
+ return color_seg
diff --git a/Text2Human/models/sample_model.py b/Text2Human/models/sample_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..708f5def222ae4bfa0c746c81e36032de226b254
--- /dev/null
+++ b/Text2Human/models/sample_model.py
@@ -0,0 +1,500 @@
+import logging
+
+import numpy as np
+import torch
+import torch.distributions as dists
+import torch.nn.functional as F
+from torchvision.utils import save_image
+
+from models.archs.fcn_arch import FCNHead, MultiHeadFCNHead
+from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
+from models.archs.transformer_arch import TransformerMultiHead
+from models.archs.unet_arch import ShapeUNet, UNet
+from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
+ VectorQuantizer,
+ VectorQuantizerSpatialTextureAware,
+ VectorQuantizerTexture)
+
+logger = logging.getLogger('base')
+
+
+class BaseSampleModel():
+ """Base Model"""
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device(opt['device'])
+
+ # hierarchical VQVAE
+ self.decoder = Decoder(
+ in_channels=opt['top_in_channels'],
+ resolution=opt['top_resolution'],
+ z_channels=opt['top_z_channels'],
+ ch=opt['top_ch'],
+ out_ch=opt['top_out_ch'],
+ num_res_blocks=opt['top_num_res_blocks'],
+ attn_resolutions=opt['top_attn_resolutions'],
+ ch_mult=opt['top_ch_mult'],
+ dropout=opt['top_dropout'],
+ resamp_with_conv=True,
+ give_pre_end=False).to(self.device)
+ self.top_quantize = VectorQuantizerTexture(
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["top_z_channels"],
+ 1).to(self.device)
+ self.load_top_pretrain_models()
+
+ self.bot_decoder_res = DecoderRes(
+ in_channels=opt['bot_in_channels'],
+ resolution=opt['bot_resolution'],
+ z_channels=opt['bot_z_channels'],
+ ch=opt['bot_ch'],
+ num_res_blocks=opt['bot_num_res_blocks'],
+ ch_mult=opt['bot_ch_mult'],
+ dropout=opt['bot_dropout'],
+ give_pre_end=False).to(self.device)
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
+ opt['bot_n_embed'],
+ opt['embed_dim'],
+ beta=0.25,
+ spatial_size=opt['bot_codebook_spatial_size']).to(self.device)
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["bot_z_channels"],
+ 1).to(self.device)
+ self.load_bot_pretrain_network()
+
+ # top -> bot prediction
+ self.index_pred_guidance_encoder = UNet(
+ in_channels=opt['index_pred_encoder_in_channels']).to(self.device)
+ self.index_pred_decoder = MultiHeadFCNHead(
+ in_channels=opt['index_pred_fc_in_channels'],
+ in_index=opt['index_pred_fc_in_index'],
+ channels=opt['index_pred_fc_channels'],
+ num_convs=opt['index_pred_fc_num_convs'],
+ concat_input=opt['index_pred_fc_concat_input'],
+ dropout_ratio=opt['index_pred_fc_dropout_ratio'],
+ num_classes=opt['index_pred_fc_num_classes'],
+ align_corners=opt['index_pred_fc_align_corners'],
+ num_head=18).to(self.device)
+ self.load_index_pred_network()
+
+ # VAE for segmentation mask
+ self.segm_encoder = Encoder(
+ ch=opt['segm_ch'],
+ num_res_blocks=opt['segm_num_res_blocks'],
+ attn_resolutions=opt['segm_attn_resolutions'],
+ ch_mult=opt['segm_ch_mult'],
+ in_channels=opt['segm_in_channels'],
+ resolution=opt['segm_resolution'],
+ z_channels=opt['segm_z_channels'],
+ double_z=opt['segm_double_z'],
+ dropout=opt['segm_dropout']).to(self.device)
+ self.segm_quantizer = VectorQuantizer(
+ opt['segm_n_embed'],
+ opt['segm_embed_dim'],
+ beta=0.25,
+ sane_index_shape=True).to(self.device)
+ self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
+ opt['segm_embed_dim'],
+ 1).to(self.device)
+ self.load_pretrained_segm_token()
+
+ # define sampler
+ self.sampler_fn = TransformerMultiHead(
+ codebook_size=opt['codebook_size'],
+ segm_codebook_size=opt['segm_codebook_size'],
+ texture_codebook_size=opt['texture_codebook_size'],
+ bert_n_emb=opt['bert_n_emb'],
+ bert_n_layers=opt['bert_n_layers'],
+ bert_n_head=opt['bert_n_head'],
+ block_size=opt['block_size'],
+ latent_shape=opt['latent_shape'],
+ embd_pdrop=opt['embd_pdrop'],
+ resid_pdrop=opt['resid_pdrop'],
+ attn_pdrop=opt['attn_pdrop'],
+ num_head=opt['num_head']).to(self.device)
+ self.load_sampler_pretrained_network()
+
+ self.shape = tuple(opt['latent_shape'])
+
+ self.mask_id = opt['codebook_size']
+ self.sample_steps = opt['sample_steps']
+
+ def load_top_pretrain_models(self):
+ # load pretrained vqgan
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=torch.device('cpu'))
+
+ self.decoder.load_state_dict(
+ top_vae_checkpoint['decoder'], strict=True)
+ self.top_quantize.load_state_dict(
+ top_vae_checkpoint['quantize'], strict=True)
+ self.top_post_quant_conv.load_state_dict(
+ top_vae_checkpoint['post_quant_conv'], strict=True)
+
+ self.decoder.eval()
+ self.top_quantize.eval()
+ self.top_post_quant_conv.eval()
+
+ def load_bot_pretrain_network(self):
+ checkpoint = torch.load(self.opt['bot_vae_path'], map_location=torch.device('cpu'))
+ self.bot_decoder_res.load_state_dict(
+ checkpoint['bot_decoder_res'], strict=True)
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
+ self.bot_quantize.load_state_dict(
+ checkpoint['bot_quantize'], strict=True)
+ self.bot_post_quant_conv.load_state_dict(
+ checkpoint['bot_post_quant_conv'], strict=True)
+
+ self.bot_decoder_res.eval()
+ self.decoder.eval()
+ self.bot_quantize.eval()
+ self.bot_post_quant_conv.eval()
+
+ def load_pretrained_segm_token(self):
+ # load pretrained vqgan for segmentation mask
+ segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=torch.device('cpu'))
+ self.segm_encoder.load_state_dict(
+ segm_token_checkpoint['encoder'], strict=True)
+ self.segm_quantizer.load_state_dict(
+ segm_token_checkpoint['quantize'], strict=True)
+ self.segm_quant_conv.load_state_dict(
+ segm_token_checkpoint['quant_conv'], strict=True)
+
+ self.segm_encoder.eval()
+ self.segm_quantizer.eval()
+ self.segm_quant_conv.eval()
+
+ def load_index_pred_network(self):
+ checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=torch.device('cpu'))
+ self.index_pred_guidance_encoder.load_state_dict(
+ checkpoint['guidance_encoder'], strict=True)
+ self.index_pred_decoder.load_state_dict(
+ checkpoint['index_decoder'], strict=True)
+
+ self.index_pred_guidance_encoder.eval()
+ self.index_pred_decoder.eval()
+
+ def load_sampler_pretrained_network(self):
+ checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=torch.device('cpu'))
+ self.sampler_fn.load_state_dict(checkpoint, strict=True)
+ self.sampler_fn.eval()
+
+ def bot_index_prediction(self, feature_top, texture_mask):
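+ # Predict the fine-level (bottom) codebook indices from the quantized
+ # top-level features. The decoder has one head per texture-aware codebook
+ # (18 in total); the downsampled texture mask routes every spatial
+ # position to the head of its texture class.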
+ self.index_pred_guidance_encoder.eval()
+ self.index_pred_decoder.eval()
+
+ texture_tokens = F.interpolate(
+ texture_mask, (32, 16), mode='nearest').view(self.batch_size,
+ -1).long()
+
+ texture_mask_flatten = texture_tokens.view(-1)
+ min_encodings_indices_list = [
+ torch.full(
+ texture_mask_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=texture_mask_flatten.device) for _ in range(18)
+ ]
+ with torch.no_grad():
+ feature_enc = self.index_pred_guidance_encoder(feature_top)
+ memory_logits_list = self.index_pred_decoder(feature_enc)
+ for codebook_idx, memory_logits in enumerate(memory_logits_list):
+ region_of_interest = texture_mask_flatten == codebook_idx
+ if torch.sum(region_of_interest) > 0:
+ memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
+ min_encodings_indices_list[codebook_idx][
+ region_of_interest] = memory_indices_pred[
+ region_of_interest]
+ min_encodings_indices_return_list = [
+ min_encodings_indices.view((1, 32, 16))
+ for min_encodings_indices in min_encodings_indices_list
+ ]
+
+ return min_encodings_indices_return_list
+
+ def sample_and_refine(self, save_dir=None, img_name=None):
+ # sample the 32x16 grid of top-level feature indices
+ sampled_top_indices_list = self.sample_fn(
+ temp=1, sample_steps=self.sample_steps)
+
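+ # For each sample: look up the sampled top-level tokens in the texture-aware
+ # codebook, predict the bottom-level indices from them, and run the
+ # hierarchical decoder on both levels to synthesize the final image.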
+ for sample_idx in range(self.batch_size):
+ sample_indices = [
+ sampled_indices_cur[sample_idx:sample_idx + 1]
+ for sampled_indices_cur in sampled_top_indices_list
+ ]
+ top_quant = self.top_quantize.get_codebook_entry(
+ sample_indices, self.texture_mask[sample_idx:sample_idx + 1],
+ (sample_indices[0].size(0), self.shape[0], self.shape[1],
+ self.opt["top_z_channels"]))
+
+ top_quant = self.top_post_quant_conv(top_quant)
+
+ bot_indices_list = self.bot_index_prediction(
+ top_quant, self.texture_mask[sample_idx:sample_idx + 1])
+
+ quant_bot = self.bot_quantize.get_codebook_entry(
+ bot_indices_list, self.texture_mask[sample_idx:sample_idx + 1],
+ (bot_indices_list[0].size(0), bot_indices_list[0].size(1),
+ bot_indices_list[0].size(2),
+ self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
+ quant_bot = self.bot_post_quant_conv(quant_bot)
+ bot_dec_res = self.bot_decoder_res(quant_bot)
+
+ dec = self.decoder(top_quant, bot_h=bot_dec_res)
+
+ dec = ((dec + 1) / 2)
+ dec = dec.clamp_(0, 1)
+ if save_dir is None and img_name is None:
+ return dec
+ else:
+ save_image(
+ dec,
+ f'{save_dir}/{img_name[sample_idx]}',
+ nrow=1,
+ padding=4)
+
+ def sample_fn(self, temp=1.0, sample_steps=None):
+ self.sampler_fn.eval()
+
+ x_t = torch.ones((self.batch_size, np.prod(self.shape)),
+ device=self.device).long() * self.mask_id
+ unmasked = torch.zeros_like(x_t, device=self.device).bool()
+ sample_steps = list(range(1, sample_steps + 1))
+
+ texture_tokens = F.interpolate(
+ self.texture_mask, (32, 16),
+ mode='nearest').view(self.batch_size, -1).long()
+
+ texture_mask_flatten = texture_tokens.view(-1)
+
+ # min_encodings_indices_list is used later to decode the sampled image
+ min_encodings_indices_list = [
+ torch.full(
+ texture_mask_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=texture_mask_flatten.device) for _ in range(18)
+ ]
+
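+ # Iterative mask-and-predict sampling: start from a fully masked token map
+ # and, at every step, unmask a random subset of positions by sampling from
+ # the transformer's predicted token distribution, conditioned on the
+ # segmentation and texture tokens.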
+ for t in reversed(sample_steps):
+ t = torch.full((self.batch_size, ),
+ t,
+ device=self.device,
+ dtype=torch.long)
+
+ # where to unmask
+ changes = torch.rand(
+ x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
+ # don't unmask somewhere already unmasked
+ changes = torch.bitwise_xor(changes,
+ torch.bitwise_and(changes, unmasked))
+ # update mask with changes
+ unmasked = torch.bitwise_or(unmasked, changes)
+
+ x_0_logits_list = self.sampler_fn(
+ x_t, self.segm_tokens, texture_tokens, t=t)
+
+ changes_flatten = changes.view(-1)
+ ori_shape = x_t.shape # [b, h*w]
+ x_t = x_t.view(-1) # [b*h*w]
+ for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
+ if torch.sum(texture_mask_flatten[changes_flatten] ==
+ codebook_idx) > 0:
+ # scale by temperature
+ x_0_logits = x_0_logits / temp
+ x_0_dist = dists.Categorical(logits=x_0_logits)
+ x_0_hat = x_0_dist.sample().long()
+ x_0_hat = x_0_hat.view(-1)
+
+ # only replace the changed indices with corresponding codebook_idx
+ changes_segm = torch.bitwise_and(
+ changes_flatten, texture_mask_flatten == codebook_idx)
+
+ # x_t is fed back to the transformer, so its indices must lie in one continuous range across the 18 codebooks
+ x_t[changes_segm] = x_0_hat[
+ changes_segm] + 1024 * codebook_idx
+ min_encodings_indices_list[codebook_idx][
+ changes_segm] = x_0_hat[changes_segm]
+
+ x_t = x_t.view(ori_shape) # [b, h*w]
+
+ min_encodings_indices_return_list = [
+ min_encodings_indices.view(ori_shape)
+ for min_encodings_indices in min_encodings_indices_list
+ ]
+
+ self.sampler_fn.train()
+
+ return min_encodings_indices_return_list
+
+ @torch.no_grad()
+ def get_quantized_segm(self, segm):
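+ # One-hot encode the segmentation map, encode it with the pretrained
+ # segmentation VAE and quantize the result to obtain discrete
+ # segmentation tokens.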
+ segm_one_hot = F.one_hot(
+ segm.squeeze(1).long(),
+ num_classes=self.opt['segm_num_segm_classes']).permute(
+ 0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ encoded_segm_mask = self.segm_encoder(segm_one_hot)
+ encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
+ _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
+
+ return segm_tokens
+
+
+class SampleFromParsingModel(BaseSampleModel):
+ """SampleFromParsing model.
+ """
+
+ def feed_data(self, data):
+ self.segm = data['segm'].to(self.device)
+ self.texture_mask = data['texture_mask'].to(self.device)
+ self.batch_size = self.segm.size(0)
+
+ self.segm_tokens = self.get_quantized_segm(self.segm)
+ self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
+
+ def inference(self, data_loader, save_dir):
+ for _, data in enumerate(data_loader):
+ img_name = data['img_name']
+ self.feed_data(data)
+ with torch.no_grad():
+ self.sample_and_refine(save_dir, img_name)
+
+
+class SampleFromPoseModel(BaseSampleModel):
+ """SampleFromPose model.
+ """
+
+ def __init__(self, opt):
+ super().__init__(opt)
+ # pose-to-parsing
+ self.shape_attr_embedder = ShapeAttrEmbedding(
+ dim=opt['shape_embedder_dim'],
+ out_dim=opt['shape_embedder_out_dim'],
+ cls_num_list=opt['shape_attr_class_num']).to(self.device)
+ self.shape_parsing_encoder = ShapeUNet(
+ in_channels=opt['shape_encoder_in_channels']).to(self.device)
+ self.shape_parsing_decoder = FCNHead(
+ in_channels=opt['shape_fc_in_channels'],
+ in_index=opt['shape_fc_in_index'],
+ channels=opt['shape_fc_channels'],
+ num_convs=opt['shape_fc_num_convs'],
+ concat_input=opt['shape_fc_concat_input'],
+ dropout_ratio=opt['shape_fc_dropout_ratio'],
+ num_classes=opt['shape_fc_num_classes'],
+ align_corners=opt['shape_fc_align_corners'],
+ ).to(self.device)
+ self.load_shape_generation_models()
+
+ self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
+ [250, 235, 215], [255, 250, 205], [211, 211, 211],
+ [70, 130, 180], [127, 255, 212], [0, 100, 0],
+ [50, 205, 50], [255, 255, 0], [245, 222, 179],
+ [255, 140, 0], [255, 0, 0], [16, 78, 139],
+ [144, 238, 144], [50, 205, 174], [50, 155, 250],
+ [160, 140, 88], [213, 140, 88], [90, 140, 90],
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
+
+ def load_shape_generation_models(self):
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=torch.device('cpu'))
+
+ self.shape_attr_embedder.load_state_dict(
+ checkpoint['embedder'], strict=True)
+ self.shape_attr_embedder.eval()
+
+ self.shape_parsing_encoder.load_state_dict(
+ checkpoint['encoder'], strict=True)
+ self.shape_parsing_encoder.eval()
+
+ self.shape_parsing_decoder.load_state_dict(
+ checkpoint['decoder'], strict=True)
+ self.shape_parsing_decoder.eval()
+
+ def feed_data(self, data):
+ self.pose = data['densepose'].to(self.device)
+ self.batch_size = self.pose.size(0)
+
+ self.shape_attr = data['shape_attr'].to(self.device)
+ self.upper_fused_attr = data['upper_fused_attr'].to(self.device)
+ self.lower_fused_attr = data['lower_fused_attr'].to(self.device)
+ self.outer_fused_attr = data['outer_fused_attr'].to(self.device)
+
+ def inference(self, data_loader, save_dir):
+ for _, data in enumerate(data_loader):
+ img_name = data['img_name']
+ self.feed_data(data)
+ with torch.no_grad():
+ self.generate_parsing_map()
+ self.generate_quantized_segm()
+ self.generate_texture_map()
+ self.sample_and_refine(save_dir, img_name)
+
+ def generate_parsing_map(self):
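+ # Predict the human parsing map from the densepose input and the shape
+ # attribute embedding.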
+ with torch.no_grad():
+ attr_embedding = self.shape_attr_embedder(self.shape_attr)
+ pose_enc = self.shape_parsing_encoder(self.pose, attr_embedding)
+ seg_logits = self.shape_parsing_decoder(pose_enc)
+ self.segm = seg_logits.argmax(dim=1)
+ self.segm = self.segm.unsqueeze(1)
+
+ def generate_quantized_segm(self):
+ self.segm_tokens = self.get_quantized_segm(self.segm)
+ self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
+
+ def generate_texture_map(self):
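+ # Build a per-pixel texture map from the predicted parsing: upper, lower
+ # and outer garment regions are filled with their fused texture attribute
+ # shifted by 1 (0 marks untextured regions; an attribute value of 17 is
+ # treated as unspecified and skipped).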
+ upper_cls = [1., 4.]
+ lower_cls = [3., 5., 21.]
+ outer_cls = [2.]
+
+ mask_batch = []
+ for idx in range(self.batch_size):
+ mask = torch.zeros_like(self.segm[idx])
+ upper_fused_attr = self.upper_fused_attr[idx]
+ lower_fused_attr = self.lower_fused_attr[idx]
+ outer_fused_attr = self.outer_fused_attr[idx]
+ if upper_fused_attr != 17:
+ for cls in upper_cls:
+ mask[self.segm[idx] == cls] = upper_fused_attr + 1
+
+ if lower_fused_attr != 17:
+ for cls in lower_cls:
+ mask[self.segm[idx] == cls] = lower_fused_attr + 1
+
+ if outer_fused_attr != 17:
+ for cls in outer_cls:
+ mask[self.segm[idx] == cls] = outer_fused_attr + 1
+
+ mask_batch.append(mask)
+ self.texture_mask = torch.stack(mask_batch, dim=0).to(torch.float32)
+
+ def feed_pose_data(self, pose_img):
+ # for ui demo
+
+ self.pose = pose_img.to(self.device)
+ self.batch_size = self.pose.size(0)
+
+ def feed_shape_attributes(self, shape_attr):
+ # for ui demo
+
+ self.shape_attr = shape_attr.to(self.device)
+
+ def feed_texture_attributes(self, texture_attr):
+ # for ui demo
+
+ self.upper_fused_attr = texture_attr[0].unsqueeze(0).to(self.device)
+ self.lower_fused_attr = texture_attr[1].unsqueeze(0).to(self.device)
+ self.outer_fused_attr = texture_attr[2].unsqueeze(0).to(self.device)
+
+ def palette_result(self, result):
+
+ seg = result[0]
+ palette = np.array(self.palette)
+ assert palette.shape[1] == 3
+ assert len(palette.shape) == 2
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ # BGR conversion is skipped here; the result stays in RGB
+ # color_seg = color_seg[..., ::-1]
+ return color_seg
diff --git a/Text2Human/models/transformer_model.py b/Text2Human/models/transformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7db0f3e26924c161f56855346af994171b345365
--- /dev/null
+++ b/Text2Human/models/transformer_model.py
@@ -0,0 +1,482 @@
+import logging
+import math
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.distributions as dists
+import torch.nn.functional as F
+from torchvision.utils import save_image
+
+from models.archs.transformer_arch import TransformerMultiHead
+from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
+ VectorQuantizerTexture)
+
+logger = logging.getLogger('base')
+
+
+class TransformerTextureAwareModel():
+ """Texture-Aware Diffusion based Transformer model.
+ """
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda')
+ self.is_train = opt['is_train']
+
+ # VQVAE for image
+ self.img_encoder = Encoder(
+ ch=opt['img_ch'],
+ num_res_blocks=opt['img_num_res_blocks'],
+ attn_resolutions=opt['img_attn_resolutions'],
+ ch_mult=opt['img_ch_mult'],
+ in_channels=opt['img_in_channels'],
+ resolution=opt['img_resolution'],
+ z_channels=opt['img_z_channels'],
+ double_z=opt['img_double_z'],
+ dropout=opt['img_dropout']).to(self.device)
+ self.img_decoder = Decoder(
+ in_channels=opt['img_in_channels'],
+ resolution=opt['img_resolution'],
+ z_channels=opt['img_z_channels'],
+ ch=opt['img_ch'],
+ out_ch=opt['img_out_ch'],
+ num_res_blocks=opt['img_num_res_blocks'],
+ attn_resolutions=opt['img_attn_resolutions'],
+ ch_mult=opt['img_ch_mult'],
+ dropout=opt['img_dropout'],
+ resamp_with_conv=True,
+ give_pre_end=False).to(self.device)
+ self.img_quantizer = VectorQuantizerTexture(
+ opt['img_n_embed'], opt['img_embed_dim'],
+ beta=0.25).to(self.device)
+ self.img_quant_conv = torch.nn.Conv2d(opt["img_z_channels"],
+ opt['img_embed_dim'],
+ 1).to(self.device)
+ self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
+ opt["img_z_channels"],
+ 1).to(self.device)
+ self.load_pretrained_image_vae()
+
+ # VAE for segmentation mask
+ self.segm_encoder = Encoder(
+ ch=opt['segm_ch'],
+ num_res_blocks=opt['segm_num_res_blocks'],
+ attn_resolutions=opt['segm_attn_resolutions'],
+ ch_mult=opt['segm_ch_mult'],
+ in_channels=opt['segm_in_channels'],
+ resolution=opt['segm_resolution'],
+ z_channels=opt['segm_z_channels'],
+ double_z=opt['segm_double_z'],
+ dropout=opt['segm_dropout']).to(self.device)
+ self.segm_quantizer = VectorQuantizer(
+ opt['segm_n_embed'],
+ opt['segm_embed_dim'],
+ beta=0.25,
+ sane_index_shape=True).to(self.device)
+ self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
+ opt['segm_embed_dim'],
+ 1).to(self.device)
+ self.load_pretrained_segm_vae()
+
+ # define sampler
+ self._denoise_fn = TransformerMultiHead(
+ codebook_size=opt['codebook_size'],
+ segm_codebook_size=opt['segm_codebook_size'],
+ texture_codebook_size=opt['texture_codebook_size'],
+ bert_n_emb=opt['bert_n_emb'],
+ bert_n_layers=opt['bert_n_layers'],
+ bert_n_head=opt['bert_n_head'],
+ block_size=opt['block_size'],
+ latent_shape=opt['latent_shape'],
+ embd_pdrop=opt['embd_pdrop'],
+ resid_pdrop=opt['resid_pdrop'],
+ attn_pdrop=opt['attn_pdrop'],
+ num_head=opt['num_head']).to(self.device)
+
+ self.num_classes = opt['codebook_size']
+ self.shape = tuple(opt['latent_shape'])
+ self.num_timesteps = 1000
+
+ self.mask_id = opt['codebook_size']
+ self.loss_type = opt['loss_type']
+ self.mask_schedule = opt['mask_schedule']
+
+ self.sample_steps = opt['sample_steps']
+
+ self.init_training_settings()
+
+ def load_pretrained_image_vae(self):
+ # load the pretrained vqgan for images
+ img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
+ self.img_encoder.load_state_dict(
+ img_ae_checkpoint['encoder'], strict=True)
+ self.img_decoder.load_state_dict(
+ img_ae_checkpoint['decoder'], strict=True)
+ self.img_quantizer.load_state_dict(
+ img_ae_checkpoint['quantize'], strict=True)
+ self.img_quant_conv.load_state_dict(
+ img_ae_checkpoint['quant_conv'], strict=True)
+ self.img_post_quant_conv.load_state_dict(
+ img_ae_checkpoint['post_quant_conv'], strict=True)
+ self.img_encoder.eval()
+ self.img_decoder.eval()
+ self.img_quantizer.eval()
+ self.img_quant_conv.eval()
+ self.img_post_quant_conv.eval()
+
+ def load_pretrained_segm_vae(self):
+ # load pretrained vqgan for segmentation mask
+ segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
+ self.segm_encoder.load_state_dict(
+ segm_ae_checkpoint['encoder'], strict=True)
+ self.segm_quantizer.load_state_dict(
+ segm_ae_checkpoint['quantize'], strict=True)
+ self.segm_quant_conv.load_state_dict(
+ segm_ae_checkpoint['quant_conv'], strict=True)
+ self.segm_encoder.eval()
+ self.segm_quantizer.eval()
+ self.segm_quant_conv.eval()
+
+ def init_training_settings(self):
+ optim_params = []
+ for v in self._denoise_fn.parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ # set up optimizer
+ self.optimizer = torch.optim.Adam(
+ optim_params,
+ self.opt['lr'],
+ weight_decay=self.opt['weight_decay'])
+ self.log_dict = OrderedDict()
+
+ @torch.no_grad()
+ def get_quantized_img(self, image, texture_mask):
+ encoded_img = self.img_encoder(image)
+ encoded_img = self.img_quant_conv(encoded_img)
+
+ # img_tokens_input: the continuous (globally indexed) tokens fed to the transformer
+ # img_tokens_gt_list: the ground-truth indices for each of the 18 texture-aware codebooks
+ _, _, [_, img_tokens_input, img_tokens_gt_list
+ ] = self.img_quantizer(encoded_img, texture_mask)
+
+ # reshape the tokens
+ b = image.size(0)
+ img_tokens_input = img_tokens_input.view(b, -1)
+ img_tokens_gt_return_list = [
+ img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
+ ]
+
+ return img_tokens_input, img_tokens_gt_return_list
+
+ @torch.no_grad()
+ def decode(self, quant):
+ quant = self.img_post_quant_conv(quant)
+ dec = self.img_decoder(quant)
+ return dec
+
+ @torch.no_grad()
+ def decode_image_indices(self, indices_list, texture_mask):
+ quant = self.img_quantizer.get_codebook_entry(
+ indices_list, texture_mask,
+ (indices_list[0].size(0), self.shape[0], self.shape[1],
+ self.opt["img_z_channels"]))
+ dec = self.decode(quant)
+
+ return dec
+
+ def sample_time(self, b, device, method='uniform'):
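+ # Draw diffusion timesteps either uniformly or in proportion to the
+ # running loss history (importance sampling); pt is the probability of
+ # each drawn timestep and is used to reweight the loss.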
+ if method == 'importance':
+ if not (self.Lt_count > 10).all():
+ return self.sample_time(b, device, method='uniform')
+
+ Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
+ Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
+ pt_all = Lt_sqrt / Lt_sqrt.sum()
+
+ t = torch.multinomial(pt_all, num_samples=b, replacement=True)
+
+ pt = pt_all.gather(dim=0, index=t)
+
+ return t, pt
+
+ elif method == 'uniform':
+ t = torch.randint(
+ 1, self.num_timesteps + 1, (b, ), device=device).long()
+ pt = torch.ones_like(t).float() / self.num_timesteps
+ return t, pt
+
+ else:
+ raise ValueError
+
+ def q_sample(self, x_0, x_0_gt_list, t):
+ # samples q(x_t | x_0)
+ # randomly set token to mask with probability t/T
+ # x_t, x_0_ignore = x_0.clone(), x_0.clone()
+ x_t = x_0.clone()
+
+ mask = torch.rand_like(x_t.float()) < (
+ t.float().unsqueeze(-1) / self.num_timesteps)
+ x_t[mask] = self.mask_id
+ # x_0_ignore[torch.bitwise_not(mask)] = -1
+
+ # for every gt token list, we also need to do the mask
+ x_0_gt_ignore_list = []
+ for x_0_gt in x_0_gt_list:
+ x_0_gt_ignore = x_0_gt.clone()
+ x_0_gt_ignore[torch.bitwise_not(mask)] = -1
+ x_0_gt_ignore_list.append(x_0_gt_ignore)
+
+ return x_t, x_0_gt_ignore_list, mask
+
+ def _train_loss(self, x_0, x_0_gt_list):
+ b, device = x_0.size(0), x_0.device
+
+ # choose what time steps to compute loss at
+ t, pt = self.sample_time(b, device, 'uniform')
+
+ # make x noisy and denoise
+ if self.mask_schedule == 'random':
+ x_t, x_0_gt_ignore_list, mask = self.q_sample(
+ x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
+ else:
+ raise NotImplementedError
+
+ # sample p(x_0 | x_t)
+ x_0_hat_logits_list = self._denoise_fn(
+ x_t, self.segm_tokens, self.texture_tokens, t=t)
+
+ # Always compute ELBO for comparison purposes
+ cross_entropy_loss = 0
+ for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
+ x_0_gt_ignore_list):
+ cross_entropy_loss += F.cross_entropy(
+ x_0_hat_logits.permute(0, 2, 1),
+ x_0_gt_ignore,
+ ignore_index=-1,
+ reduction='none').sum(1)
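+ # ELBO-style weighting: divide by the sampled timestep t and by its
+ # sampling probability pt (importance weighting), then normalize from
+ # nats to bits per dimension.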
+ vb_loss = cross_entropy_loss / t
+ vb_loss = vb_loss / pt
+ vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
+ if self.loss_type == 'elbo':
+ loss = vb_loss
+ elif self.loss_type == 'mlm':
+ denom = mask.float().sum(1)
+ denom[denom == 0] = 1 # prevent divide by 0 errors.
+ loss = cross_entropy_loss / denom
+ elif self.loss_type == 'reweighted_elbo':
+ weight = (1 - (t / self.num_timesteps))
+ loss = weight * cross_entropy_loss
+ loss = loss / (math.log(2) * x_0.shape[1:].numel())
+ else:
+ raise ValueError
+
+ return loss.mean(), vb_loss.mean()
+
+ def feed_data(self, data):
+ self.image = data['image'].to(self.device)
+ self.segm = data['segm'].to(self.device)
+ self.texture_mask = data['texture_mask'].to(self.device)
+ self.input_indices, self.gt_indices_list = self.get_quantized_img(
+ self.image, self.texture_mask)
+
+ self.texture_tokens = F.interpolate(
+ self.texture_mask, size=self.shape,
+ mode='nearest').view(self.image.size(0), -1).long()
+
+ self.segm_tokens = self.get_quantized_segm(self.segm)
+ self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)
+
+ def optimize_parameters(self):
+ self._denoise_fn.train()
+
+ loss, vb_loss = self._train_loss(self.input_indices,
+ self.gt_indices_list)
+
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ self.log_dict['loss'] = loss
+ self.log_dict['vb_loss'] = vb_loss
+
+ self._denoise_fn.eval()
+
+ @torch.no_grad()
+ def get_quantized_segm(self, segm):
+ segm_one_hot = F.one_hot(
+ segm.squeeze(1).long(),
+ num_classes=self.opt['segm_num_segm_classes']).permute(
+ 0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ encoded_segm_mask = self.segm_encoder(segm_one_hot)
+ encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
+ _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
+
+ return segm_tokens
+
+ def sample_fn(self, temp=1.0, sample_steps=None):
+ self._denoise_fn.eval()
+
+ b, device = self.image.size(0), 'cuda'
+ x_t = torch.ones(
+ (b, np.prod(self.shape)), device=device).long() * self.mask_id
+ unmasked = torch.zeros_like(x_t, device=device).bool()
+ sample_steps = list(range(1, sample_steps + 1))
+
+ texture_mask_flatten = self.texture_tokens.view(-1)
+
+ # min_encodings_indices_list is used later to decode the sampled image
+ min_encodings_indices_list = [
+ torch.full(
+ texture_mask_flatten.size(),
+ fill_value=-1,
+ dtype=torch.long,
+ device=texture_mask_flatten.device) for _ in range(18)
+ ]
+
+ for t in reversed(sample_steps):
+ print(f'Sample timestep {t:4d}', end='\r')
+ t = torch.full((b, ), t, device=device, dtype=torch.long)
+
+ # where to unmask
+ changes = torch.rand(
+ x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+ # don't unmask somewhere already unmasked
+ changes = torch.bitwise_xor(changes,
+ torch.bitwise_and(changes, unmasked))
+ # update mask with changes
+ unmasked = torch.bitwise_or(unmasked, changes)
+
+ x_0_logits_list = self._denoise_fn(
+ x_t, self.segm_tokens, self.texture_tokens, t=t)
+
+ changes_flatten = changes.view(-1)
+ ori_shape = x_t.shape # [b, h*w]
+ x_t = x_t.view(-1) # [b*h*w]
+ for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
+ if torch.sum(texture_mask_flatten[changes_flatten] ==
+ codebook_idx) > 0:
+ # scale by temperature
+ x_0_logits = x_0_logits / temp
+ x_0_dist = dists.Categorical(logits=x_0_logits)
+ x_0_hat = x_0_dist.sample().long()
+ x_0_hat = x_0_hat.view(-1)
+
+ # only replace the changed indices with corresponding codebook_idx
+ changes_segm = torch.bitwise_and(
+ changes_flatten, texture_mask_flatten == codebook_idx)
+
+ # x_t is fed back to the transformer, so its indices must lie in one continuous range across the 18 codebooks
+ x_t[changes_segm] = x_0_hat[
+ changes_segm] + 1024 * codebook_idx
+ min_encodings_indices_list[codebook_idx][
+ changes_segm] = x_0_hat[changes_segm]
+
+ x_t = x_t.view(ori_shape) # [b, h*w]
+
+ min_encodings_indices_return_list = [
+ min_encodings_indices.view(ori_shape)
+ for min_encodings_indices in min_encodings_indices_list
+ ]
+
+ self._denoise_fn.train()
+
+ return min_encodings_indices_return_list
+
+ def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
+ save_path):
+ # original image
+ ori_img = self.decode_image_indices(gt_indices, texture_mask)
+ # pred image
+ pred_img = self.decode_image_indices(predicted_indices, texture_mask)
+ img_cat = torch.cat([
+ image,
+ ori_img,
+ pred_img,
+ ], dim=3).detach()
+ img_cat = ((img_cat + 1) / 2)
+ img_cat = img_cat.clamp_(0, 1)
+ save_image(img_cat, save_path, nrow=1, padding=4)
+
+ def inference(self, data_loader, save_dir):
+ self._denoise_fn.eval()
+
+ for _, data in enumerate(data_loader):
+ img_name = data['img_name']
+ self.feed_data(data)
+ b = self.image.size(0)
+ with torch.no_grad():
+ sampled_indices_list = self.sample_fn(
+ temp=1, sample_steps=self.sample_steps)
+ for idx in range(b):
+ self.get_vis(self.image[idx:idx + 1], [
+ gt_indices[idx:idx + 1]
+ for gt_indices in self.gt_indices_list
+ ], [
+ sampled_indices[idx:idx + 1]
+ for sampled_indices in sampled_indices_list
+ ], self.texture_mask[idx:idx + 1],
+ f'{save_dir}/{img_name[idx]}')
+
+ self._denoise_fn.train()
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def update_learning_rate(self, epoch, iters=None):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
+ Default: -1.
+ """
+ lr = self.optimizer.param_groups[0]['lr']
+
+ if self.opt['lr_decay'] == 'step':
+ lr = self.opt['lr'] * (
+ self.opt['gamma']**(epoch // self.opt['step']))
+ elif self.opt['lr_decay'] == 'cos':
+ lr = self.opt['lr'] * (
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
+ elif self.opt['lr_decay'] == 'linear':
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
+ elif self.opt['lr_decay'] == 'linear2exp':
+ if epoch < self.opt['turning_point'] + 1:
+ # decay the learning rate linearly so that it has dropped by 95%
+ # at the turning point (1 / 0.95 = 1.0526)
+ lr = self.opt['lr'] * (
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
+ else:
+ lr *= self.opt['gamma']
+ elif self.opt['lr_decay'] == 'schedule':
+ if epoch in self.opt['schedule']:
+ lr *= self.opt['gamma']
+ elif self.opt['lr_decay'] == 'warm_up':
+ if iters <= self.opt['warmup_iters']:
+ lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
+ else:
+ lr = self.opt['lr']
+ else:
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
+ # set learning rate
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = lr
+
+ return lr
+
+ def save_network(self, net, save_path):
+ """Save networks.
+
+ Args:
+ net (nn.Module): Network to be saved.
+ net_label (str): Network label.
+ current_iter (int): Current iter number.
+ """
+ state_dict = net.state_dict()
+ torch.save(state_dict, save_path)
+
+ def load_network(self):
+ checkpoint = torch.load(self.opt['pretrained_sampler'])
+ self._denoise_fn.load_state_dict(checkpoint, strict=True)
+ self._denoise_fn.eval()
diff --git a/Text2Human/models/vqgan_model.py b/Text2Human/models/vqgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..13a2e7062c4b49052e91ac3c183eaa7056986050
--- /dev/null
+++ b/Text2Human/models/vqgan_model.py
@@ -0,0 +1,551 @@
+import math
+import sys
+from collections import OrderedDict
+
+sys.path.append('..')
+import lpips
+import torch
+import torch.nn.functional as F
+from torchvision.utils import save_image
+
+from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
+ VectorQuantizer, VectorQuantizerTexture)
+from models.losses.segmentation_loss import BCELossWithQuant
+from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
+ calculate_adaptive_weight, hinge_d_loss)
+
+
+class VQModel():
+
+ def __init__(self, opt):
+ super().__init__()
+ self.opt = opt
+ self.device = torch.device('cuda')
+ self.encoder = Encoder(
+ ch=opt['ch'],
+ num_res_blocks=opt['num_res_blocks'],
+ attn_resolutions=opt['attn_resolutions'],
+ ch_mult=opt['ch_mult'],
+ in_channels=opt['in_channels'],
+ resolution=opt['resolution'],
+ z_channels=opt['z_channels'],
+ double_z=opt['double_z'],
+ dropout=opt['dropout']).to(self.device)
+ self.decoder = Decoder(
+ in_channels=opt['in_channels'],
+ resolution=opt['resolution'],
+ z_channels=opt['z_channels'],
+ ch=opt['ch'],
+ out_ch=opt['out_ch'],
+ num_res_blocks=opt['num_res_blocks'],
+ attn_resolutions=opt['attn_resolutions'],
+ ch_mult=opt['ch_mult'],
+ dropout=opt['dropout'],
+ resamp_with_conv=True,
+ give_pre_end=False).to(self.device)
+ self.quantize = VectorQuantizer(
+ opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
+ self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
+ 1).to(self.device)
+ self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["z_channels"],
+ 1).to(self.device)
+
+ def init_training_settings(self):
+ self.loss = BCELossWithQuant()
+ self.log_dict = OrderedDict()
+ self.configure_optimizers()
+
+ def save_network(self, save_path):
+ """Save networks.
+
+ Args:
+ net (nn.Module): Network to be saved.
+ net_label (str): Network label.
+ current_iter (int): Current iter number.
+ """
+
+ save_dict = {}
+ save_dict['encoder'] = self.encoder.state_dict()
+ save_dict['decoder'] = self.decoder.state_dict()
+ save_dict['quantize'] = self.quantize.state_dict()
+ save_dict['quant_conv'] = self.quant_conv.state_dict()
+ save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
+ save_dict['discriminator'] = self.disc.state_dict()
+ torch.save(save_dict, save_path)
+
+ def load_network(self):
+ checkpoint = torch.load(self.opt['pretrained_models'])
+ self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
+ self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
+ self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
+ self.post_quant_conv.load_state_dict(
+ checkpoint['post_quant_conv'], strict=True)
+
+ def optimize_parameters(self, data, current_iter):
+ self.encoder.train()
+ self.decoder.train()
+ self.quantize.train()
+ self.quant_conv.train()
+ self.post_quant_conv.train()
+
+ loss = self.training_step(data)
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward_step(self, input):
+ quant, diff, _ = self.encode(input)
+ dec = self.decode(quant)
+ return dec, diff
+
+ def feed_data(self, data):
+ x = data['segm']
+ x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])
+
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ return x.float().to(self.device)
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def update_learning_rate(self, epoch):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
+ Default: -1.
+ """
+ lr = self.optimizer.param_groups[0]['lr']
+
+ if self.opt['lr_decay'] == 'step':
+ lr = self.opt['lr'] * (
+ self.opt['gamma']**(epoch // self.opt['step']))
+ elif self.opt['lr_decay'] == 'cos':
+ lr = self.opt['lr'] * (
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
+ elif self.opt['lr_decay'] == 'linear':
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
+ elif self.opt['lr_decay'] == 'linear2exp':
+ if epoch < self.opt['turning_point'] + 1:
+ # decay the learning rate linearly so that it has dropped by 95%
+ # at the turning point (1 / 0.95 = 1.0526)
+ lr = self.opt['lr'] * (
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
+ else:
+ lr *= self.opt['gamma']
+ elif self.opt['lr_decay'] == 'schedule':
+ if epoch in self.opt['schedule']:
+ lr *= self.opt['gamma']
+ else:
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
+ # set learning rate
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = lr
+
+ return lr
+
+
+class VQSegmentationModel(VQModel):
+
+ def __init__(self, opt):
+ super().__init__(opt)
+ self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
+ 1).to(self.device)
+
+ self.init_training_settings()
+
+ def configure_optimizers(self):
+ self.optimizer = torch.optim.Adam(
+ list(self.encoder.parameters()) + list(self.decoder.parameters()) +
+ list(self.quantize.parameters()) +
+ list(self.quant_conv.parameters()) +
+ list(self.post_quant_conv.parameters()),
+ lr=self.opt['lr'],
+ betas=(0.5, 0.9))
+
+ def training_step(self, data):
+ x = self.feed_data(data)
+ xrec, qloss = self.forward_step(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
+ self.log_dict.update(log_dict_ae)
+ return aeloss
+
+ def to_rgb(self, x):
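+ # Project the one-hot segmentation channels to 3 channels with a fixed
+ # random projection and rescale the result to [-1, 1] for visualization.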
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+ @torch.no_grad()
+ def inference(self, data_loader, save_dir):
+ self.encoder.eval()
+ self.decoder.eval()
+ self.quantize.eval()
+ self.quant_conv.eval()
+ self.post_quant_conv.eval()
+
+ loss_total = 0
+ loss_bce = 0
+ loss_quant = 0
+ num = 0
+
+ for _, data in enumerate(data_loader):
+ img_name = data['img_name'][0]
+ x = self.feed_data(data)
+ xrec, qloss = self.forward_step(x)
+ _, log_dict_ae = self.loss(qloss, x, xrec, split="val")
+
+ loss_total += log_dict_ae['val/total_loss']
+ loss_bce += log_dict_ae['val/bce_loss']
+ loss_quant += log_dict_ae['val/quant_loss']
+
+ num += x.size(0)
+
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+
+ img_cat = torch.cat([x, xrec], dim=3).detach()
+ img_cat = ((img_cat + 1) / 2)
+ img_cat = img_cat.clamp_(0, 1)
+ save_image(
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
+
+ return (loss_total / num).item(), (loss_bce /
+ num).item(), (loss_quant /
+ num).item()
+
+
+class VQImageModel(VQModel):
+
+ def __init__(self, opt):
+ super().__init__(opt)
+ self.disc = Discriminator(
+ opt['n_channels'], opt['ndf'],
+ n_layers=opt['disc_layers']).to(self.device)
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
+ self.perceptual_weight = opt['perceptual_weight']
+ self.disc_start_step = opt['disc_start_step']
+ self.disc_weight_max = opt['disc_weight_max']
+ self.diff_aug = opt['diff_aug']
+ self.policy = "color,translation"
+
+ self.disc.train()
+
+ self.init_training_settings()
+
+ def feed_data(self, data):
+ x = data['image']
+
+ return x.float().to(self.device)
+
+ def init_training_settings(self):
+ self.log_dict = OrderedDict()
+ self.configure_optimizers()
+
+ def configure_optimizers(self):
+ self.optimizer = torch.optim.Adam(
+ list(self.encoder.parameters()) + list(self.decoder.parameters()) +
+ list(self.quantize.parameters()) +
+ list(self.quant_conv.parameters()) +
+ list(self.post_quant_conv.parameters()),
+ lr=self.opt['lr'])
+
+ self.disc_optimizer = torch.optim.Adam(
+ self.disc.parameters(), lr=self.opt['lr'])
+
+ def training_step(self, data, step):
+ x = self.feed_data(data)
+ xrec, codebook_loss = self.forward_step(x)
+
+ # get recon/perceptual loss
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
+ nll_loss = torch.mean(nll_loss)
+
+ # augment for input to discriminator
+ if self.diff_aug:
+ xrec = DiffAugment(xrec, policy=self.policy)
+
+ # update generator
+ logits_fake = self.disc(xrec)
+ g_loss = -torch.mean(logits_fake)
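+ # weight the adversarial term adaptively by balancing its gradient against
+ # the reconstruction gradient at the decoder's last layer, and gate it to 0
+ # until disc_start_step iterations have passed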
+ last_layer = self.decoder.conv_out.weight
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
+ self.disc_weight_max)
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
+ loss = nll_loss + d_weight * g_loss + codebook_loss
+
+ self.log_dict["loss"] = loss
+ self.log_dict["l1"] = recon_loss.mean().item()
+ self.log_dict["perceptual"] = p_loss.mean().item()
+ self.log_dict["nll_loss"] = nll_loss.item()
+ self.log_dict["g_loss"] = g_loss.item()
+ self.log_dict["d_weight"] = d_weight
+ self.log_dict["codebook_loss"] = codebook_loss.item()
+
+ if step > self.disc_start_step:
+ if self.diff_aug:
+ logits_real = self.disc(
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
+ else:
+ logits_real = self.disc(x.contiguous().detach())
+ logits_fake = self.disc(xrec.contiguous().detach(
+ )) # detach so that the generator isn't also updated
+ d_loss = hinge_d_loss(logits_real, logits_fake)
+ self.log_dict["d_loss"] = d_loss
+ else:
+ d_loss = None
+
+ return loss, d_loss
+
+ def optimize_parameters(self, data, step):
+ self.encoder.train()
+ self.decoder.train()
+ self.quantize.train()
+ self.quant_conv.train()
+ self.post_quant_conv.train()
+
+ loss, d_loss = self.training_step(data, step)
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ if step > self.disc_start_step:
+ self.disc_optimizer.zero_grad()
+ d_loss.backward()
+ self.disc_optimizer.step()
+
+ @torch.no_grad()
+ def inference(self, data_loader, save_dir):
+ self.encoder.eval()
+ self.decoder.eval()
+ self.quantize.eval()
+ self.quant_conv.eval()
+ self.post_quant_conv.eval()
+
+ loss_total = 0
+ num = 0
+
+ for _, data in enumerate(data_loader):
+ img_name = data['img_name'][0]
+ x = self.feed_data(data)
+ xrec, _ = self.forward_step(x)
+
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
+ nll_loss = torch.mean(nll_loss)
+ loss_total += nll_loss
+
+ num += x.size(0)
+
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+
+ img_cat = torch.cat([x, xrec], dim=3).detach()
+ img_cat = ((img_cat + 1) / 2)
+ img_cat = img_cat.clamp_(0, 1)
+ save_image(
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
+
+ return (loss_total / num).item()
+
+
+class VQImageSegmTextureModel(VQImageModel):
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda')
+ self.encoder = Encoder(
+ ch=opt['ch'],
+ num_res_blocks=opt['num_res_blocks'],
+ attn_resolutions=opt['attn_resolutions'],
+ ch_mult=opt['ch_mult'],
+ in_channels=opt['in_channels'],
+ resolution=opt['resolution'],
+ z_channels=opt['z_channels'],
+ double_z=opt['double_z'],
+ dropout=opt['dropout']).to(self.device)
+ self.decoder = Decoder(
+ in_channels=opt['in_channels'],
+ resolution=opt['resolution'],
+ z_channels=opt['z_channels'],
+ ch=opt['ch'],
+ out_ch=opt['out_ch'],
+ num_res_blocks=opt['num_res_blocks'],
+ attn_resolutions=opt['attn_resolutions'],
+ ch_mult=opt['ch_mult'],
+ dropout=opt['dropout'],
+ resamp_with_conv=True,
+ give_pre_end=False).to(self.device)
+ self.quantize = VectorQuantizerTexture(
+ opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
+ self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
+ 1).to(self.device)
+ self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
+ opt["z_channels"],
+ 1).to(self.device)
+
+ self.disc = Discriminator(
+ opt['n_channels'], opt['ndf'],
+ n_layers=opt['disc_layers']).to(self.device)
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
+ self.perceptual_weight = opt['perceptual_weight']
+ self.disc_start_step = opt['disc_start_step']
+ self.disc_weight_max = opt['disc_weight_max']
+ self.diff_aug = opt['diff_aug']
+ self.policy = "color,translation"
+
+ self.disc.train()
+
+ self.init_training_settings()
+
+ def feed_data(self, data):
+ x = data['image'].float().to(self.device)
+ mask = data['texture_mask'].float().to(self.device)
+
+ return x, mask
+
+ def training_step(self, data, step):
+ x, mask = self.feed_data(data)
+ xrec, codebook_loss = self.forward_step(x, mask)
+
+ # get recon/perceptual loss
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
+ nll_loss = torch.mean(nll_loss)
+
+ # augment for input to discriminator
+ if self.diff_aug:
+ xrec = DiffAugment(xrec, policy=self.policy)
+
+ # update generator
+ logits_fake = self.disc(xrec)
+ g_loss = -torch.mean(logits_fake)
+ last_layer = self.decoder.conv_out.weight
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
+ self.disc_weight_max)
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
+ loss = nll_loss + d_weight * g_loss + codebook_loss
+
+ self.log_dict["loss"] = loss
+ self.log_dict["l1"] = recon_loss.mean().item()
+ self.log_dict["perceptual"] = p_loss.mean().item()
+ self.log_dict["nll_loss"] = nll_loss.item()
+ self.log_dict["g_loss"] = g_loss.item()
+ self.log_dict["d_weight"] = d_weight
+ self.log_dict["codebook_loss"] = codebook_loss.item()
+
+ if step > self.disc_start_step:
+ if self.diff_aug:
+ logits_real = self.disc(
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
+ else:
+ logits_real = self.disc(x.contiguous().detach())
+ logits_fake = self.disc(xrec.contiguous().detach(
+ )) # detach so that the generator isn't also updated
+ d_loss = hinge_d_loss(logits_real, logits_fake)
+ self.log_dict["d_loss"] = d_loss
+ else:
+ d_loss = None
+
+ return loss, d_loss
+
+ @torch.no_grad()
+ def inference(self, data_loader, save_dir):
+ self.encoder.eval()
+ self.decoder.eval()
+ self.quantize.eval()
+ self.quant_conv.eval()
+ self.post_quant_conv.eval()
+
+ loss_total = 0
+ num = 0
+
+ for _, data in enumerate(data_loader):
+ img_name = data['img_name'][0]
+ x, mask = self.feed_data(data)
+ xrec, _ = self.forward_step(x, mask)
+
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
+ nll_loss = torch.mean(nll_loss)
+ loss_total += nll_loss
+
+ num += x.size(0)
+
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+
+ img_cat = torch.cat([x, xrec], dim=3).detach()
+ img_cat = ((img_cat + 1) / 2)
+ img_cat = img_cat.clamp_(0, 1)
+ save_image(
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
+
+ return (loss_total / num).item()
+
+ def encode(self, x, mask):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h, mask)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward_step(self, input, mask):
+ quant, diff, _ = self.encode(input, mask)
+ dec = self.decode(quant)
+ return dec, diff
diff --git a/Text2Human/sample_from_parsing.py b/Text2Human/sample_from_parsing.py
new file mode 100644
index 0000000000000000000000000000000000000000..954f389e7e3b320c763e755400ea5fd6aaf8736d
--- /dev/null
+++ b/Text2Human/sample_from_parsing.py
@@ -0,0 +1,53 @@
+import argparse
+import logging
+import os.path as osp
+import random
+
+import torch
+
+from data.segm_attr_dataset import DeepFashionAttrSegmDataset
+from models import create_model
+from utils.logger import get_root_logger
+from utils.options import dict2str, dict_to_nonedict, parse
+from utils.util import make_exp_dirs, set_random_seed
+
+
+def main():
+ # options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ args = parser.parse_args()
+ opt = parse(args.opt, is_train=False)
+
+ # mkdir and loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
+ logger = get_root_logger(
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
+ logger.info(dict2str(opt))
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = dict_to_nonedict(opt)
+
+ # random seed
+ seed = opt['manual_seed']
+ if seed is None:
+ seed = random.randint(1, 10000)
+ logger.info(f'Random seed: {seed}')
+ set_random_seed(seed)
+
+ test_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['test_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['test_ann_file'])
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dataset, batch_size=4, shuffle=False)
+ logger.info(f'Number of test set: {len(test_dataset)}.')
+
+ model = create_model(opt)
+ _ = model.inference(test_loader, opt['path']['results_root'])
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Text2Human/sample_from_pose.py b/Text2Human/sample_from_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad1efa7835a5977dbf7fc99ebe037d2f3452d27c
--- /dev/null
+++ b/Text2Human/sample_from_pose.py
@@ -0,0 +1,52 @@
+import argparse
+import logging
+import os.path as osp
+import random
+
+import torch
+
+from data.pose_attr_dataset import DeepFashionAttrPoseDataset
+from models import create_model
+from utils.logger import get_root_logger
+from utils.options import dict2str, dict_to_nonedict, parse
+from utils.util import make_exp_dirs, set_random_seed
+
+
+def main():
+ # options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ args = parser.parse_args()
+ opt = parse(args.opt, is_train=False)
+
+ # mkdir and loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
+ logger = get_root_logger(
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
+ logger.info(dict2str(opt))
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = dict_to_nonedict(opt)
+
+ # random seed
+ seed = opt['manual_seed']
+ if seed is None:
+ seed = random.randint(1, 10000)
+ logger.info(f'Random seed: {seed}')
+ set_random_seed(seed)
+
+ test_dataset = DeepFashionAttrPoseDataset(
+ pose_dir=opt['pose_dir'],
+ texture_ann_dir=opt['texture_ann_file'],
+ shape_ann_path=opt['shape_ann_path'])
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dataset, batch_size=4, shuffle=False)
+ logger.info(f'Number of test set: {len(test_dataset)}.')
+
+ model = create_model(opt)
+ _ = model.inference(test_loader, opt['path']['results_root'])
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Text2Human/train_index_prediction.py b/Text2Human/train_index_prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c66dca912b94f4f2903edb8373978d8d6ae7c0
--- /dev/null
+++ b/Text2Human/train_index_prediction.py
@@ -0,0 +1,133 @@
+import argparse
+import logging
+import os
+import os.path as osp
+import random
+import time
+
+import torch
+
+from data.segm_attr_dataset import DeepFashionAttrSegmDataset
+from models import create_model
+from utils.logger import MessageLogger, get_root_logger, init_tb_logger
+from utils.options import dict2str, dict_to_nonedict, parse
+from utils.util import make_exp_dirs
+
+
+def main():
+ # options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ args = parser.parse_args()
+ opt = parse(args.opt, is_train=True)
+
+ # mkdir and loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
+ logger.info(dict2str(opt))
+ # initialize tensorboard logger
+ tb_logger = None
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = dict_to_nonedict(opt)
+
+ # set up data loader
+ train_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['train_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['train_ann_file'],
+ xflip=True)
+ train_loader = torch.utils.data.DataLoader(
+ dataset=train_dataset,
+ batch_size=opt['batch_size'],
+ shuffle=True,
+ num_workers=opt['num_workers'],
+ drop_last=True)
+ logger.info(f'Number of train set: {len(train_dataset)}.')
+ opt['max_iters'] = opt['num_epochs'] * len(
+ train_dataset) // opt['batch_size']
+
+ val_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['train_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['val_ann_file'])
+ val_loader = torch.utils.data.DataLoader(
+ dataset=val_dataset, batch_size=1, shuffle=False)
+ logger.info(f'Number of val set: {len(val_dataset)}.')
+
+ test_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['test_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['test_ann_file'])
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dataset, batch_size=1, shuffle=False)
+ logger.info(f'Number of test set: {len(test_dataset)}.')
+
+ current_iter = 0
+ best_epoch = None
+ best_acc = 0
+
+ model = create_model(opt)
+
+ data_time, iter_time = 0, 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ for epoch in range(opt['num_epochs']):
+ lr = model.update_learning_rate(epoch)
+
+ for _, batch_data in enumerate(train_loader):
+ data_time = time.time() - data_time
+
+ current_iter += 1
+
+ model.feed_data(batch_data)
+ model.optimize_parameters()
+
+ iter_time = time.time() - iter_time
+ if current_iter % opt['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': [lr]})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ data_time = time.time()
+ iter_time = time.time()
+
+ if epoch % opt['val_freq'] == 0:
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ val_acc = model.inference(val_loader, save_dir)
+
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ test_acc = model.inference(test_loader, save_dir)
+
+ logger.info(
+ f'Epoch: {epoch}, val_acc: {val_acc: .4f}, test_acc: {test_acc: .4f}.'
+ )
+
+ if test_acc > best_acc:
+ best_epoch = epoch
+ best_acc = test_acc
+
+ logger.info(f'Best epoch: {best_epoch}, '
+ f'Best test acc: {best_acc: .4f}.')
+
+ # save model
+ model.save_network(
+ f'{opt["path"]["models"]}/models_epoch{epoch}.pth')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Text2Human/train_parsing_gen.py b/Text2Human/train_parsing_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..e032b002b1252e289266d048f6326e02b36a023c
--- /dev/null
+++ b/Text2Human/train_parsing_gen.py
@@ -0,0 +1,136 @@
+import argparse
+import logging
+import os
+import os.path as osp
+import random
+import time
+
+import torch
+
+from data.parsing_generation_segm_attr_dataset import \
+ ParsingGenerationDeepFashionAttrSegmDataset
+from models import create_model
+from utils.logger import MessageLogger, get_root_logger, init_tb_logger
+from utils.options import dict2str, dict_to_nonedict, parse
+from utils.util import make_exp_dirs
+
+
+def main():
+ # options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ args = parser.parse_args()
+ opt = parse(args.opt, is_train=True)
+
+ # mkdir and loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
+ logger.info(dict2str(opt))
+ # initialize tensorboard logger
+ tb_logger = None
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = dict_to_nonedict(opt)
+
+ # set up data loader
+ train_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_file=opt['train_ann_file'])
+ train_loader = torch.utils.data.DataLoader(
+ dataset=train_dataset,
+ batch_size=opt['batch_size'],
+ shuffle=True,
+ num_workers=opt['num_workers'],
+ drop_last=True)
+ logger.info(f'Number of train set: {len(train_dataset)}.')
+ opt['max_iters'] = opt['num_epochs'] * len(
+ train_dataset) // opt['batch_size']
+
+ val_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_file=opt['val_ann_file'])
+ val_loader = torch.utils.data.DataLoader(
+ dataset=val_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=opt['num_workers'])
+ logger.info(f'Number of val set: {len(val_dataset)}.')
+
+ test_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_file=opt['test_ann_file'])
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=opt['num_workers'])
+ logger.info(f'Number of test set: {len(test_dataset)}.')
+
+ current_iter = 0
+ best_epoch = None
+ best_acc = 0
+
+ model = create_model(opt)
+
+ data_time, iter_time = 0, 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ for epoch in range(opt['num_epochs']):
+ lr = model.update_learning_rate(epoch)
+
+ for _, batch_data in enumerate(train_loader):
+ data_time = time.time() - data_time
+
+ current_iter += 1
+
+ model.feed_data(batch_data)
+ model.optimize_parameters()
+
+ iter_time = time.time() - iter_time
+ if current_iter % opt['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': [lr]})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ data_time = time.time()
+ iter_time = time.time()
+
+ if epoch % opt['val_freq'] == 0:
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ val_acc = model.inference(val_loader, save_dir)
+
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ test_acc = model.inference(test_loader, save_dir)
+
+ logger.info(f'Epoch: {epoch}, '
+ f'val_acc: {val_acc: .4f}, '
+ f'test_acc: {test_acc: .4f}.')
+
+ if test_acc > best_acc:
+ best_epoch = epoch
+ best_acc = test_acc
+
+ logger.info(f'Best epoch: {best_epoch}, '
+ f'Best test acc: {best_acc: .4f}.')
+
+ # save model
+ model.save_network(
+ f'{opt["path"]["models"]}/parsing_generation_epoch{epoch}.pth')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Text2Human/train_parsing_token.py b/Text2Human/train_parsing_token.py
new file mode 100644
index 0000000000000000000000000000000000000000..c58effd1f8bd1e622a7757da01f30b3de7007831
--- /dev/null
+++ b/Text2Human/train_parsing_token.py
@@ -0,0 +1,122 @@
+import argparse
+import logging
+import os
+import os.path as osp
+import random
+import time
+
+import torch
+
+from data.mask_dataset import MaskDataset
+from models import create_model
+from utils.logger import MessageLogger, get_root_logger, init_tb_logger
+from utils.options import dict2str, dict_to_nonedict, parse
+from utils.util import make_exp_dirs
+
+
+def main():
+ # options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ args = parser.parse_args()
+ opt = parse(args.opt, is_train=True)
+
+ # mkdir and loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
+ logger.info(dict2str(opt))
+ # initialize tensorboard logger
+ tb_logger = None
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = dict_to_nonedict(opt)
+
+ # set up data loader
+ train_dataset = MaskDataset(
+ segm_dir=opt['segm_dir'], ann_dir=opt['train_ann_file'], xflip=True)
+ train_loader = torch.utils.data.DataLoader(
+ dataset=train_dataset,
+ batch_size=opt['batch_size'],
+ shuffle=True,
+ num_workers=opt['num_workers'],
+ persistent_workers=True,
+ drop_last=True)
+ logger.info(f'Number of train set: {len(train_dataset)}.')
+ opt['max_iters'] = opt['num_epochs'] * len(
+ train_dataset) // opt['batch_size']
+
+ val_dataset = MaskDataset(
+ segm_dir=opt['segm_dir'], ann_dir=opt['val_ann_file'])
+ val_loader = torch.utils.data.DataLoader(
+ dataset=val_dataset, batch_size=1, shuffle=False)
+ logger.info(f'Number of val set: {len(val_dataset)}.')
+
+ test_dataset = MaskDataset(
+ segm_dir=opt['segm_dir'], ann_dir=opt['test_ann_file'])
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dataset, batch_size=1, shuffle=False)
+ logger.info(f'Number of test set: {len(test_dataset)}.')
+
+ current_iter = 0
+ best_epoch = None
+ best_loss = 100000
+
+ model = create_model(opt)
+
+ data_time, iter_time = 0, 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ for epoch in range(opt['num_epochs']):
+ lr = model.update_learning_rate(epoch)
+
+ for _, batch_data in enumerate(train_loader):
+ data_time = time.time() - data_time
+
+ current_iter += 1
+
+ model.optimize_parameters(batch_data, current_iter)
+
+ iter_time = time.time() - iter_time
+ if current_iter % opt['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': [lr]})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ data_time = time.time()
+ iter_time = time.time()
+
+ if epoch % opt['val_freq'] == 0:
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ val_loss_total, _, _ = model.inference(val_loader, save_dir)
+
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ test_loss_total, _, _ = model.inference(test_loader, save_dir)
+
+ logger.info(f'Epoch: {epoch}, '
+ f'val_loss_total: {val_loss_total}, '
+ f'test_loss_total: {test_loss_total}.')
+
+ if test_loss_total < best_loss:
+ best_epoch = epoch
+ best_loss = test_loss_total
+
+ logger.info(f'Best epoch: {best_epoch}, '
+ f'Best test loss: {best_loss: .4f}.')
+
+ # save model
+ model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Text2Human/train_sampler.py b/Text2Human/train_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f2fc975d519d70915a31bda04063314f4dbdf2
--- /dev/null
+++ b/Text2Human/train_sampler.py
@@ -0,0 +1,122 @@
+import argparse
+import logging
+import os
+import os.path as osp
+import random
+import time
+
+import torch
+
+from data.segm_attr_dataset import DeepFashionAttrSegmDataset
+from models import create_model
+from utils.logger import MessageLogger, get_root_logger, init_tb_logger
+from utils.options import dict2str, dict_to_nonedict, parse
+from utils.util import make_exp_dirs
+
+
+def main():
+ # options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ args = parser.parse_args()
+ opt = parse(args.opt, is_train=True)
+
+ # mkdir and loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
+ logger.info(dict2str(opt))
+ # initialize tensorboard logger
+ tb_logger = None
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = dict_to_nonedict(opt)
+
+ # set up data loader
+ train_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['train_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['train_ann_file'],
+ xflip=True)
+ train_loader = torch.utils.data.DataLoader(
+ dataset=train_dataset,
+ batch_size=opt['batch_size'],
+ shuffle=True,
+ num_workers=opt['num_workers'],
+ persistent_workers=True,
+ drop_last=True)
+ logger.info(f'Number of train set: {len(train_dataset)}.')
+ opt['max_iters'] = opt['num_epochs'] * len(
+ train_dataset) // opt['batch_size']
+
+ val_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['train_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['val_ann_file'])
+ val_loader = torch.utils.data.DataLoader(
+ dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False)
+ logger.info(f'Number of val set: {len(val_dataset)}.')
+
+ test_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['test_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['test_ann_file'])
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False)
+ logger.info(f'Number of test set: {len(test_dataset)}.')
+
+ current_iter = 0
+
+ model = create_model(opt)
+
+ data_time, iter_time = 0, 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ for epoch in range(opt['num_epochs']):
+ lr = model.update_learning_rate(epoch, current_iter)
+
+ for _, batch_data in enumerate(train_loader):
+ data_time = time.time() - data_time
+
+ current_iter += 1
+
+ model.feed_data(batch_data)
+ model.optimize_parameters()
+
+ iter_time = time.time() - iter_time
+ if current_iter % opt['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': [lr]})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ data_time = time.time()
+ iter_time = time.time()
+
+ if epoch % opt['val_freq'] == 0 and epoch != 0:
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ model.inference(val_loader, save_dir)
+
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ model.inference(test_loader, save_dir)
+
+ # save model
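+        # only the sampler's denoising network (model._denoise_fn) is checkpointed here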
+ model.save_network(
+ model._denoise_fn,
+ f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Text2Human/train_vqvae.py b/Text2Human/train_vqvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..107702af553e9acb6281586b447279006b304e24
--- /dev/null
+++ b/Text2Human/train_vqvae.py
@@ -0,0 +1,132 @@
+import argparse
+import logging
+import os
+import os.path as osp
+import random
+import time
+
+import torch
+
+from data.segm_attr_dataset import DeepFashionAttrSegmDataset
+from models import create_model
+from utils.logger import MessageLogger, get_root_logger, init_tb_logger
+from utils.options import dict2str, dict_to_nonedict, parse
+from utils.util import make_exp_dirs
+
+
+def main():
+ # options
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+ args = parser.parse_args()
+ opt = parse(args.opt, is_train=True)
+
+ # mkdir and loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
+ logger.info(dict2str(opt))
+ # initialize tensorboard logger
+ tb_logger = None
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
+
+ # convert to NoneDict, which returns None for missing keys
+ opt = dict_to_nonedict(opt)
+
+ # set up data loader
+ train_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['train_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['train_ann_file'],
+ xflip=True)
+ train_loader = torch.utils.data.DataLoader(
+ dataset=train_dataset,
+ batch_size=opt['batch_size'],
+ shuffle=True,
+ num_workers=opt['num_workers'],
+ persistent_workers=True,
+ drop_last=True)
+ logger.info(f'Number of train set: {len(train_dataset)}.')
+ opt['max_iters'] = opt['num_epochs'] * len(
+ train_dataset) // opt['batch_size']
+
+ val_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['train_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['val_ann_file'])
+ val_loader = torch.utils.data.DataLoader(
+ dataset=val_dataset, batch_size=1, shuffle=False)
+ logger.info(f'Number of val set: {len(val_dataset)}.')
+
+ test_dataset = DeepFashionAttrSegmDataset(
+ img_dir=opt['test_img_dir'],
+ segm_dir=opt['segm_dir'],
+ pose_dir=opt['pose_dir'],
+ ann_dir=opt['test_ann_file'])
+ test_loader = torch.utils.data.DataLoader(
+ dataset=test_dataset, batch_size=1, shuffle=False)
+ logger.info(f'Number of test set: {len(test_dataset)}.')
+
+ current_iter = 0
+ best_epoch = None
+ best_loss = 100000
+
+ model = create_model(opt)
+
+ data_time, iter_time = 0, 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ for epoch in range(opt['num_epochs']):
+ lr = model.update_learning_rate(epoch)
+
+ for _, batch_data in enumerate(train_loader):
+ data_time = time.time() - data_time
+
+ current_iter += 1
+
+ model.optimize_parameters(batch_data, current_iter)
+
+ iter_time = time.time() - iter_time
+ if current_iter % opt['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': [lr]})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ data_time = time.time()
+ iter_time = time.time()
+
+ if epoch % opt['val_freq'] == 0:
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ val_loss_total = model.inference(val_loader, save_dir)
+
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
+ os.makedirs(save_dir, exist_ok=opt['debug'])
+ test_loss_total = model.inference(test_loader, save_dir)
+
+ logger.info(f'Epoch: {epoch}, '
+ f'val_loss_total: {val_loss_total}, '
+ f'test_loss_total: {test_loss_total}.')
+
+ if test_loss_total < best_loss:
+ best_epoch = epoch
+ best_loss = test_loss_total
+
+ logger.info(f'Best epoch: {best_epoch}, '
+ f'Best test loss: {best_loss: .4f}.')
+
+ # save model
+ model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Text2Human/ui/__init__.py b/Text2Human/ui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Text2Human/ui/mouse_event.py b/Text2Human/ui/mouse_event.py
new file mode 100644
index 0000000000000000000000000000000000000000..87c5f85e0fde810bb72c0814352e30f475900d34
--- /dev/null
+++ b/Text2Human/ui/mouse_event.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+from PyQt5.QtWidgets import *
+
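+# palette for the 24 human-parsing classes; the index of each color doubles as
+# the drawing mode / class id used by the demo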
+color_list = [
+ QColor(0, 0, 0),
+ QColor(255, 250, 250),
+ QColor(220, 220, 220),
+ QColor(250, 235, 215),
+ QColor(255, 250, 205),
+ QColor(211, 211, 211),
+ QColor(70, 130, 180),
+ QColor(127, 255, 212),
+ QColor(0, 100, 0),
+ QColor(50, 205, 50),
+ QColor(255, 255, 0),
+ QColor(245, 222, 179),
+ QColor(255, 140, 0),
+ QColor(255, 0, 0),
+ QColor(16, 78, 139),
+ QColor(144, 238, 144),
+ QColor(50, 205, 174),
+ QColor(50, 155, 250),
+ QColor(160, 140, 88),
+ QColor(213, 140, 88),
+ QColor(90, 140, 90),
+ QColor(185, 210, 205),
+ QColor(130, 165, 180),
+ QColor(225, 141, 151)
+]
+
+
+class GraphicsScene(QGraphicsScene):
+
+ def __init__(self, mode, size, parent=None):
+ QGraphicsScene.__init__(self, parent)
+ self.mode = mode
+ self.size = size
+ self.mouse_clicked = False
+ self.prev_pt = None
+
+ # self.masked_image = None
+
+ # save the points
+ self.mask_points = []
+ for i in range(len(color_list)):
+ self.mask_points.append([])
+
+ # save the size of points
+ self.size_points = []
+ for i in range(len(color_list)):
+ self.size_points.append([])
+
+ # save the history of edit
+ self.history = []
+
+ def reset(self):
+ # save the points
+ self.mask_points = []
+ for i in range(len(color_list)):
+ self.mask_points.append([])
+ # save the size of points
+ self.size_points = []
+ for i in range(len(color_list)):
+ self.size_points.append([])
+ # save the history of edit
+ self.history = []
+
+ self.mode = 0
+ self.prev_pt = None
+
+ def mousePressEvent(self, event):
+ self.mouse_clicked = True
+
+ def mouseReleaseEvent(self, event):
+ self.prev_pt = None
+ self.mouse_clicked = False
+
+ def mouseMoveEvent(self, event): # drawing
+ if self.mouse_clicked:
+ if self.prev_pt:
+ self.drawMask(self.prev_pt, event.scenePos(),
+ color_list[self.mode], self.size)
+ pts = {}
+ pts['prev'] = (int(self.prev_pt.x()), int(self.prev_pt.y()))
+ pts['curr'] = (int(event.scenePos().x()),
+ int(event.scenePos().y()))
+
+ self.size_points[self.mode].append(self.size)
+ self.mask_points[self.mode].append(pts)
+ self.history.append(self.mode)
+ self.prev_pt = event.scenePos()
+ else:
+ self.prev_pt = event.scenePos()
+
+ def drawMask(self, prev_pt, curr_pt, color, size):
+ lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt))
+ lineItem.setPen(QPen(color, size, Qt.SolidLine)) # rect
+ self.addItem(lineItem)
+
+ def erase_prev_pt(self):
+ self.prev_pt = None
+
+ def reset_items(self):
+ for i in range(len(self.items())):
+ item = self.items()[0]
+ self.removeItem(item)
+
+ def undo(self):
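+        # items() is ordered topmost-first, so this removes the 8 most recent
+        # line segments when enough exist (otherwise everything except the
+        # first item) and pops the matching stroke records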
+ if len(self.items()) > 1:
+ if len(self.items()) >= 9:
+ for i in range(8):
+ item = self.items()[0]
+ self.removeItem(item)
+ if self.history[-1] == self.mode:
+ self.mask_points[self.mode].pop()
+ self.size_points[self.mode].pop()
+ self.history.pop()
+ else:
+ for i in range(len(self.items()) - 1):
+ item = self.items()[0]
+ self.removeItem(item)
+ if self.history[-1] == self.mode:
+ self.mask_points[self.mode].pop()
+ self.size_points[self.mode].pop()
+ self.history.pop()
diff --git a/Text2Human/ui/ui.py b/Text2Human/ui/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..179a5ee796d5f8561eacf64b16d9a713b64983cf
--- /dev/null
+++ b/Text2Human/ui/ui.py
@@ -0,0 +1,313 @@
+from PyQt5 import QtCore, QtGui, QtWidgets
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+from PyQt5.QtWidgets import *
+
+
+class Ui_Form(object):
+
+ def setupUi(self, Form):
+ Form.setObjectName("Form")
+ Form.resize(1250, 670)
+
+ self.pushButton_2 = QtWidgets.QPushButton(Form)
+ self.pushButton_2.setGeometry(QtCore.QRect(20, 60, 97, 27))
+ self.pushButton_2.setObjectName("pushButton_2")
+
+ self.pushButton_6 = QtWidgets.QPushButton(Form)
+ self.pushButton_6.setGeometry(QtCore.QRect(20, 100, 97, 27))
+ self.pushButton_6.setObjectName("pushButton_6")
+
+ # Generate Parsing
+ self.pushButton_0 = QtWidgets.QPushButton(Form)
+ self.pushButton_0.setGeometry(QtCore.QRect(126, 60, 150, 27))
+ self.pushButton_0.setObjectName("pushButton_0")
+
+ # Generate Human
+ self.pushButton_1 = QtWidgets.QPushButton(Form)
+ self.pushButton_1.setGeometry(QtCore.QRect(126, 100, 150, 27))
+ self.pushButton_1.setObjectName("pushButton_1")
+
+ # shape text box
+ self.label_heading_1 = QtWidgets.QLabel(Form)
+ self.label_heading_1.setText('Describe the shape.')
+ self.label_heading_1.setObjectName("label_heading_1")
+ self.label_heading_1.setGeometry(QtCore.QRect(320, 20, 200, 20))
+
+ self.message_box_1 = QtWidgets.QLineEdit(Form)
+ self.message_box_1.setGeometry(QtCore.QRect(320, 50, 256, 80))
+ self.message_box_1.setObjectName("message_box_1")
+ self.message_box_1.setAlignment(Qt.AlignTop)
+
+ # texture text box
+ self.label_heading_2 = QtWidgets.QLabel(Form)
+ self.label_heading_2.setText('Describe the textures.')
+ self.label_heading_2.setObjectName("label_heading_2")
+ self.label_heading_2.setGeometry(QtCore.QRect(620, 20, 200, 20))
+
+ self.message_box_2 = QtWidgets.QLineEdit(Form)
+ self.message_box_2.setGeometry(QtCore.QRect(620, 50, 256, 80))
+ self.message_box_2.setObjectName("message_box_2")
+ self.message_box_2.setAlignment(Qt.AlignTop)
+
+ # title icon
+ self.title_icon = QtWidgets.QLabel(Form)
+ self.title_icon.setGeometry(QtCore.QRect(30, 10, 200, 50))
+ self.title_icon.setPixmap(
+ QtGui.QPixmap('./ui/icons/icon_title.png').scaledToWidth(200))
+
+ # palette icon
+ self.palette_icon = QtWidgets.QLabel(Form)
+ self.palette_icon.setGeometry(QtCore.QRect(950, 10, 256, 128))
+ self.palette_icon.setPixmap(
+ QtGui.QPixmap('./ui/icons/icon_palette.png').scaledToWidth(256))
+
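+        # class palette buttons: each one switches the drawing brush to the
+        # corresponding parsing class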
+ # top
+ self.pushButton_8 = QtWidgets.QPushButton(' top', Form)
+ self.pushButton_8.setGeometry(QtCore.QRect(940, 120, 120, 27))
+ self.pushButton_8.setObjectName("pushButton_8")
+ self.pushButton_8.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_8.setIcon(QIcon('./ui/color_blocks/class_top.png'))
+ # skin
+ self.pushButton_9 = QtWidgets.QPushButton(' skin', Form)
+ self.pushButton_9.setGeometry(QtCore.QRect(940, 165, 120, 27))
+ self.pushButton_9.setObjectName("pushButton_9")
+ self.pushButton_9.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_9.setIcon(QIcon('./ui/color_blocks/class_skin.png'))
+ # outer
+ self.pushButton_10 = QtWidgets.QPushButton(' outer', Form)
+ self.pushButton_10.setGeometry(QtCore.QRect(940, 210, 120, 27))
+ self.pushButton_10.setObjectName("pushButton_10")
+ self.pushButton_10.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_10.setIcon(QIcon('./ui/color_blocks/class_outer.png'))
+ # face
+ self.pushButton_11 = QtWidgets.QPushButton(' face', Form)
+ self.pushButton_11.setGeometry(QtCore.QRect(940, 255, 120, 27))
+ self.pushButton_11.setObjectName("pushButton_11")
+ self.pushButton_11.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_11.setIcon(QIcon('./ui/color_blocks/class_face.png'))
+ # skirt
+ self.pushButton_12 = QtWidgets.QPushButton(' skirt', Form)
+ self.pushButton_12.setGeometry(QtCore.QRect(940, 300, 120, 27))
+ self.pushButton_12.setObjectName("pushButton_12")
+ self.pushButton_12.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_12.setIcon(QIcon('./ui/color_blocks/class_skirt.png'))
+ # hair
+ self.pushButton_13 = QtWidgets.QPushButton(' hair', Form)
+ self.pushButton_13.setGeometry(QtCore.QRect(940, 345, 120, 27))
+ self.pushButton_13.setObjectName("pushButton_13")
+ self.pushButton_13.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_13.setIcon(QIcon('./ui/color_blocks/class_hair.png'))
+ # dress
+ self.pushButton_14 = QtWidgets.QPushButton(' dress', Form)
+ self.pushButton_14.setGeometry(QtCore.QRect(940, 390, 120, 27))
+ self.pushButton_14.setObjectName("pushButton_14")
+ self.pushButton_14.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_14.setIcon(QIcon('./ui/color_blocks/class_dress.png'))
+ # headwear
+ self.pushButton_15 = QtWidgets.QPushButton(' headwear', Form)
+ self.pushButton_15.setGeometry(QtCore.QRect(940, 435, 120, 27))
+ self.pushButton_15.setObjectName("pushButton_15")
+ self.pushButton_15.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_15.setIcon(
+ QIcon('./ui/color_blocks/class_headwear.png'))
+ # pants
+ self.pushButton_16 = QtWidgets.QPushButton(' pants', Form)
+ self.pushButton_16.setGeometry(QtCore.QRect(940, 480, 120, 27))
+ self.pushButton_16.setObjectName("pushButton_16")
+ self.pushButton_16.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_16.setIcon(QIcon('./ui/color_blocks/class_pants.png'))
+ # eyeglasses
+ self.pushButton_17 = QtWidgets.QPushButton(' eyeglass', Form)
+ self.pushButton_17.setGeometry(QtCore.QRect(940, 525, 120, 27))
+ self.pushButton_17.setObjectName("pushButton_17")
+ self.pushButton_17.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_17.setIcon(
+ QIcon('./ui/color_blocks/class_eyeglass.png'))
+ # rompers
+ self.pushButton_18 = QtWidgets.QPushButton(' rompers', Form)
+ self.pushButton_18.setGeometry(QtCore.QRect(940, 570, 120, 27))
+ self.pushButton_18.setObjectName("pushButton_18")
+ self.pushButton_18.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_18.setIcon(
+ QIcon('./ui/color_blocks/class_rompers.png'))
+ # footwear
+ self.pushButton_19 = QtWidgets.QPushButton(' footwear', Form)
+ self.pushButton_19.setGeometry(QtCore.QRect(940, 615, 120, 27))
+ self.pushButton_19.setObjectName("pushButton_19")
+ self.pushButton_19.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_19.setIcon(
+ QIcon('./ui/color_blocks/class_footwear.png'))
+
+ # leggings
+ self.pushButton_20 = QtWidgets.QPushButton(' leggings', Form)
+ self.pushButton_20.setGeometry(QtCore.QRect(1100, 120, 120, 27))
+        self.pushButton_20.setObjectName("pushButton_20")
+ self.pushButton_20.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_20.setIcon(
+ QIcon('./ui/color_blocks/class_leggings.png'))
+
+ # ring
+ self.pushButton_21 = QtWidgets.QPushButton(' ring', Form)
+ self.pushButton_21.setGeometry(QtCore.QRect(1100, 165, 120, 27))
+        self.pushButton_21.setObjectName("pushButton_21")
+ self.pushButton_21.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_21.setIcon(QIcon('./ui/color_blocks/class_ring.png'))
+
+ # belt
+ self.pushButton_22 = QtWidgets.QPushButton(' belt', Form)
+ self.pushButton_22.setGeometry(QtCore.QRect(1100, 210, 120, 27))
+        self.pushButton_22.setObjectName("pushButton_22")
+ self.pushButton_22.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_22.setIcon(QIcon('./ui/color_blocks/class_belt.png'))
+
+ # neckwear
+ self.pushButton_23 = QtWidgets.QPushButton(' neckwear', Form)
+ self.pushButton_23.setGeometry(QtCore.QRect(1100, 255, 120, 27))
+        self.pushButton_23.setObjectName("pushButton_23")
+ self.pushButton_23.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_23.setIcon(
+ QIcon('./ui/color_blocks/class_neckwear.png'))
+
+ # wrist
+ self.pushButton_24 = QtWidgets.QPushButton(' wrist', Form)
+ self.pushButton_24.setGeometry(QtCore.QRect(1100, 300, 120, 27))
+        self.pushButton_24.setObjectName("pushButton_24")
+ self.pushButton_24.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_24.setIcon(QIcon('./ui/color_blocks/class_wrist.png'))
+
+ # socks
+ self.pushButton_25 = QtWidgets.QPushButton(' socks', Form)
+ self.pushButton_25.setGeometry(QtCore.QRect(1100, 345, 120, 27))
+        self.pushButton_25.setObjectName("pushButton_25")
+ self.pushButton_25.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_25.setIcon(QIcon('./ui/color_blocks/class_socks.png'))
+
+ # tie
+ self.pushButton_26 = QtWidgets.QPushButton(' tie', Form)
+ self.pushButton_26.setGeometry(QtCore.QRect(1100, 390, 120, 27))
+        self.pushButton_26.setObjectName("pushButton_26")
+ self.pushButton_26.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_26.setIcon(QIcon('./ui/color_blocks/class_tie.png'))
+
+    # necklace
+ self.pushButton_27 = QtWidgets.QPushButton(' necklace', Form)
+ self.pushButton_27.setGeometry(QtCore.QRect(1100, 435, 120, 27))
+        self.pushButton_27.setObjectName("pushButton_27")
+ self.pushButton_27.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_27.setIcon(
+ QIcon('./ui/color_blocks/class_necklace.png'))
+
+    # earstuds
+ self.pushButton_28 = QtWidgets.QPushButton(' earstuds', Form)
+ self.pushButton_28.setGeometry(QtCore.QRect(1100, 480, 120, 27))
+        self.pushButton_28.setObjectName("pushButton_28")
+ self.pushButton_28.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_28.setIcon(
+ QIcon('./ui/color_blocks/class_earstuds.png'))
+
+ # bag
+ self.pushButton_29 = QtWidgets.QPushButton(' bag', Form)
+ self.pushButton_29.setGeometry(QtCore.QRect(1100, 525, 120, 27))
+        self.pushButton_29.setObjectName("pushButton_29")
+ self.pushButton_29.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_29.setIcon(QIcon('./ui/color_blocks/class_bag.png'))
+
+ # glove
+ self.pushButton_30 = QtWidgets.QPushButton(' glove', Form)
+ self.pushButton_30.setGeometry(QtCore.QRect(1100, 570, 120, 27))
+        self.pushButton_30.setObjectName("pushButton_30")
+ self.pushButton_30.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_30.setIcon(QIcon('./ui/color_blocks/class_glove.png'))
+
+ # background
+ self.pushButton_31 = QtWidgets.QPushButton(' background', Form)
+ self.pushButton_31.setGeometry(QtCore.QRect(1100, 615, 120, 27))
+        self.pushButton_31.setObjectName("pushButton_31")
+ self.pushButton_31.setStyleSheet(
+ "text-align: left; padding-left: 10px;")
+ self.pushButton_31.setIcon(QIcon('./ui/color_blocks/class_bg.png'))
+
+ self.graphicsView = QtWidgets.QGraphicsView(Form)
+ self.graphicsView.setGeometry(QtCore.QRect(20, 140, 256, 512))
+ self.graphicsView.setObjectName("graphicsView")
+ self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
+ self.graphicsView_2.setGeometry(QtCore.QRect(320, 140, 256, 512))
+ self.graphicsView_2.setObjectName("graphicsView_2")
+ self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
+ self.graphicsView_3.setGeometry(QtCore.QRect(620, 140, 256, 512))
+ self.graphicsView_3.setObjectName("graphicsView_3")
+
+ self.retranslateUi(Form)
+ self.pushButton_2.clicked.connect(Form.open_densepose)
+ self.pushButton_6.clicked.connect(Form.save_img)
+ self.pushButton_8.clicked.connect(Form.top_mode)
+ self.pushButton_9.clicked.connect(Form.skin_mode)
+ self.pushButton_10.clicked.connect(Form.outer_mode)
+ self.pushButton_11.clicked.connect(Form.face_mode)
+ self.pushButton_12.clicked.connect(Form.skirt_mode)
+ self.pushButton_13.clicked.connect(Form.hair_mode)
+ self.pushButton_14.clicked.connect(Form.dress_mode)
+ self.pushButton_15.clicked.connect(Form.headwear_mode)
+ self.pushButton_16.clicked.connect(Form.pants_mode)
+ self.pushButton_17.clicked.connect(Form.eyeglass_mode)
+ self.pushButton_18.clicked.connect(Form.rompers_mode)
+ self.pushButton_19.clicked.connect(Form.footwear_mode)
+ self.pushButton_20.clicked.connect(Form.leggings_mode)
+ self.pushButton_21.clicked.connect(Form.ring_mode)
+ self.pushButton_22.clicked.connect(Form.belt_mode)
+ self.pushButton_23.clicked.connect(Form.neckwear_mode)
+ self.pushButton_24.clicked.connect(Form.wrist_mode)
+ self.pushButton_25.clicked.connect(Form.socks_mode)
+ self.pushButton_26.clicked.connect(Form.tie_mode)
+        self.pushButton_27.clicked.connect(Form.necklace_mode)
+        self.pushButton_28.clicked.connect(Form.earstuds_mode)
+ self.pushButton_29.clicked.connect(Form.bag_mode)
+ self.pushButton_30.clicked.connect(Form.glove_mode)
+ self.pushButton_31.clicked.connect(Form.background_mode)
+ self.pushButton_0.clicked.connect(Form.generate_parsing)
+ self.pushButton_1.clicked.connect(Form.generate_human)
+
+ QtCore.QMetaObject.connectSlotsByName(Form)
+
+ def retranslateUi(self, Form):
+ _translate = QtCore.QCoreApplication.translate
+ Form.setWindowTitle(_translate("Form", "Text2Human"))
+ self.pushButton_2.setText(_translate("Form", "Load Pose"))
+ self.pushButton_6.setText(_translate("Form", "Save Image"))
+
+ self.pushButton_0.setText(_translate("Form", "Generate Parsing"))
+ self.pushButton_1.setText(_translate("Form", "Generate Human"))
+
+
+if __name__ == "__main__":
+ import sys
+ app = QtWidgets.QApplication(sys.argv)
+ Form = QtWidgets.QWidget()
+ ui = Ui_Form()
+ ui.setupUi(Form)
+ Form.show()
+ sys.exit(app.exec_())
diff --git a/Text2Human/ui_demo.py b/Text2Human/ui_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a06cca245ab04a9420d428e681db1f4f2f5a03c2
--- /dev/null
+++ b/Text2Human/ui_demo.py
@@ -0,0 +1,285 @@
+import sys
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+from PyQt5.QtWidgets import *
+
+from models.sample_model import SampleFromPoseModel
+from ui.mouse_event import GraphicsScene
+from ui.ui import Ui_Form
+from utils.language_utils import (generate_shape_attributes,
+ generate_texture_attributes)
+from utils.options import dict_to_nonedict, parse
+
+color_list = [(0, 0, 0), (255, 250, 250), (220, 220, 220), (250, 235, 215),
+ (255, 250, 205), (211, 211, 211), (70, 130, 180),
+ (127, 255, 212), (0, 100, 0), (50, 205, 50), (255, 255, 0),
+ (245, 222, 179), (255, 140, 0), (255, 0, 0), (16, 78, 139),
+ (144, 238, 144), (50, 205, 174), (50, 155, 250), (160, 140, 88),
+ (213, 140, 88), (90, 140, 90), (185, 210, 205), (130, 165, 180),
+ (225, 141, 151)]
+
+
+class Ex(QWidget, Ui_Form):
+
+ def __init__(self, opt):
+ super(Ex, self).__init__()
+ self.setupUi(self)
+ self.show()
+
+ self.output_img = None
+
+ self.mat_img = None
+
+ self.mode = 0
+ self.size = 6
+ self.mask = None
+ self.mask_m = None
+ self.img = None
+
+ # about UI
+ self.mouse_clicked = False
+ self.scene = QGraphicsScene()
+ self.graphicsView.setScene(self.scene)
+ self.graphicsView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
+ self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+ self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+
+ self.ref_scene = GraphicsScene(self.mode, self.size)
+ self.graphicsView_2.setScene(self.ref_scene)
+ self.graphicsView_2.setAlignment(Qt.AlignTop | Qt.AlignLeft)
+ self.graphicsView_2.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+ self.graphicsView_2.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+
+ self.result_scene = QGraphicsScene()
+ self.graphicsView_3.setScene(self.result_scene)
+ self.graphicsView_3.setAlignment(Qt.AlignTop | Qt.AlignLeft)
+ self.graphicsView_3.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+ self.graphicsView_3.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+
+ self.dlg = QColorDialog(self.graphicsView)
+ self.color = None
+
+ self.sample_model = SampleFromPoseModel(opt)
+
+ def open_densepose(self):
+ fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
+ QDir.currentPath())
+ if fileName:
+ image = QPixmap(fileName)
+ mat_img = Image.open(fileName)
+ self.pose_img = mat_img.copy()
+ if image.isNull():
+ QMessageBox.information(self, "Image Viewer",
+ "Cannot load %s." % fileName)
+ return
+ image = image.scaled(self.graphicsView.size(),
+ Qt.IgnoreAspectRatio)
+
+ if len(self.scene.items()) > 0:
+ self.scene.removeItem(self.scene.items()[-1])
+ self.scene.addPixmap(image)
+
+ self.ref_scene.clear()
+ self.result_scene.clear()
+
+ # load pose to model
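+            # keep only the third image channel (assumed to hold the densepose
+            # part labels, 0-24) and rescale it to roughly [-1, 1]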
+ self.pose_img = np.array(
+ self.pose_img.resize(
+ size=(256, 512),
+ resample=Image.LANCZOS))[:, :, 2:].transpose(
+ 2, 0, 1).astype(np.float32)
+ self.pose_img = self.pose_img / 12. - 1
+
+ self.pose_img = torch.from_numpy(self.pose_img).unsqueeze(1)
+
+ self.sample_model.feed_pose_data(self.pose_img)
+
+ def generate_parsing(self):
+ self.ref_scene.reset_items()
+ self.ref_scene.reset()
+
+ shape_texts = self.message_box_1.text()
+
+ shape_attributes = generate_shape_attributes(shape_texts)
+ shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
+ self.sample_model.feed_shape_attributes(shape_attributes)
+
+ self.sample_model.generate_parsing_map()
+ self.sample_model.generate_quantized_segm()
+
+ self.colored_segm = self.sample_model.palette_result(
+ self.sample_model.segm[0].cpu())
+
+ self.mask_m = cv2.cvtColor(
+ cv2.cvtColor(self.colored_segm, cv2.COLOR_RGB2BGR),
+ cv2.COLOR_BGR2RGB)
+
+ qim = QImage(self.colored_segm.data.tobytes(),
+ self.colored_segm.shape[1], self.colored_segm.shape[0],
+ QImage.Format_RGB888)
+
+ image = QPixmap.fromImage(qim)
+
+ image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+
+ if len(self.ref_scene.items()) > 0:
+ self.ref_scene.removeItem(self.ref_scene.items()[-1])
+ self.ref_scene.addPixmap(image)
+
+ self.result_scene.clear()
+
+ def generate_human(self):
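+        # rasterize the user's brush strokes for all 24 classes onto the
+        # colored parsing map before re-quantizing it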
+ for i in range(24):
+ self.mask_m = self.make_mask(self.mask_m,
+ self.ref_scene.mask_points[i],
+ self.ref_scene.size_points[i],
+ color_list[i])
+
+ seg_map = np.full(self.mask_m.shape[:-1], -1)
+
+ # convert rgb to num
+ for index, color in enumerate(color_list):
+ seg_map[np.sum(self.mask_m == color, axis=2) == 3] = index
+ assert (seg_map != -1).all()
+
+ self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze(
+ 0).unsqueeze(0).to(self.sample_model.device)
+ self.sample_model.generate_quantized_segm()
+
+ texture_texts = self.message_box_2.text()
+ texture_attributes = generate_texture_attributes(texture_texts)
+
+ texture_attributes = torch.LongTensor(texture_attributes)
+
+ self.sample_model.feed_texture_attributes(texture_attributes)
+
+ self.sample_model.generate_texture_map()
+ result = self.sample_model.sample_and_refine()
+ result = result.permute(0, 2, 3, 1)
+ result = result.detach().cpu().numpy()
+ result = result * 255
+
+ result = np.asarray(result[0, :, :, :], dtype=np.uint8)
+
+ self.output_img = result
+
+ qim = QImage(result.data.tobytes(), result.shape[1], result.shape[0],
+ QImage.Format_RGB888)
+ image = QPixmap.fromImage(qim)
+
+ image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
+
+ if len(self.result_scene.items()) > 0:
+ self.result_scene.removeItem(self.result_scene.items()[-1])
+ self.result_scene.addPixmap(image)
+
+ def top_mode(self):
+ self.ref_scene.mode = 1
+
+ def skin_mode(self):
+ self.ref_scene.mode = 15
+
+ def outer_mode(self):
+ self.ref_scene.mode = 2
+
+ def face_mode(self):
+ self.ref_scene.mode = 14
+
+ def skirt_mode(self):
+ self.ref_scene.mode = 3
+
+ def hair_mode(self):
+ self.ref_scene.mode = 13
+
+ def dress_mode(self):
+ self.ref_scene.mode = 4
+
+ def headwear_mode(self):
+ self.ref_scene.mode = 7
+
+ def pants_mode(self):
+ self.ref_scene.mode = 5
+
+ def eyeglass_mode(self):
+ self.ref_scene.mode = 8
+
+ def rompers_mode(self):
+ self.ref_scene.mode = 21
+
+ def footwear_mode(self):
+ self.ref_scene.mode = 11
+
+ def leggings_mode(self):
+ self.ref_scene.mode = 6
+
+ def ring_mode(self):
+ self.ref_scene.mode = 16
+
+ def belt_mode(self):
+ self.ref_scene.mode = 10
+
+ def neckwear_mode(self):
+ self.ref_scene.mode = 9
+
+ def wrist_mode(self):
+ self.ref_scene.mode = 17
+
+ def socks_mode(self):
+ self.ref_scene.mode = 18
+
+ def tie_mode(self):
+ self.ref_scene.mode = 23
+
+ def earstuds_mode(self):
+ self.ref_scene.mode = 22
+
+ def necklace_mode(self):
+ self.ref_scene.mode = 20
+
+ def bag_mode(self):
+ self.ref_scene.mode = 12
+
+ def glove_mode(self):
+ self.ref_scene.mode = 19
+
+ def background_mode(self):
+ self.ref_scene.mode = 0
+
+ def make_mask(self, mask, pts, sizes, color):
+ if len(pts) > 0:
+ for idx, pt in enumerate(pts):
+ cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx])
+ return mask
+
+    def save_img(self):
+        if self.output_img is not None:
+            fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
+                                                      QDir.currentPath())
+            if fileName:
+                cv2.imwrite(fileName + '.png', self.output_img[:, :, ::-1])
+
+    def undo(self):
+        self.ref_scene.undo()
+
+ def clear(self):
+
+ self.ref_scene.reset_items()
+ self.ref_scene.reset()
+
+ self.ref_scene.clear()
+
+ self.result_scene.clear()
+
+
+if __name__ == '__main__':
+
+ app = QApplication(sys.argv)
+ opt = './configs/sample_from_pose.yml'
+ opt = parse(opt, is_train=False)
+ opt = dict_to_nonedict(opt)
+ ex = Ex(opt)
+ sys.exit(app.exec_())
diff --git a/Text2Human/ui_util/__init__.py b/Text2Human/ui_util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Text2Human/ui_util/config.py b/Text2Human/ui_util/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b3094b872bfb4077c6397369e1b3db401acde2d
--- /dev/null
+++ b/Text2Human/ui_util/config.py
@@ -0,0 +1,25 @@
+import argparse
+import logging
+import os
+
+import yaml
+
+logger = logging.getLogger()
+
+class Config(object):
+ def __init__(self, filename=None):
+ assert os.path.exists(filename), "ERROR: Config File doesn't exist."
+ try:
+ with open(filename, 'r') as f:
+                self._cfg_dict = yaml.safe_load(f)
+ # parent of IOError, OSError *and* WindowsError where available
+ except EnvironmentError:
+ logger.error('Please check the file with name of "%s"', filename)
+ logger.info(' APP CONFIG '.center(80, '-'))
+ logger.info(''.center(80, '-'))
+
+ def __getattr__(self, name):
+ value = self._cfg_dict[name]
+ if isinstance(value, dict):
+ value = DictAsMember(value)
+ return value
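+
+
+# Minimal DictAsMember helper assumed by Config.__getattr__ above, so nested
+# config sections can also be read with attribute access (e.g. cfg.section.key).
+class DictAsMember(dict):
+
+    def __getattr__(self, name):
+        value = self[name]
+        if isinstance(value, dict):
+            value = DictAsMember(value)
+        return value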
diff --git a/Text2Human/utils/__init__.py b/Text2Human/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Text2Human/utils/language_utils.py b/Text2Human/utils/language_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb2ef69b3001f10b20069f40ec0141d28260482f
--- /dev/null
+++ b/Text2Human/utils/language_utils.py
@@ -0,0 +1,315 @@
+import torch
+from sentence_transformers import SentenceTransformer, util
+
+# predefined shape text
+upper_length_text = [
+ 'sleeveless', 'without sleeves', 'sleeves have been cut off', 'tank top',
+ 'tank shirt', 'muscle shirt', 'short-sleeve', 'short sleeves',
+ 'with short sleeves', 'medium-sleeve', 'medium sleeves',
+ 'with medium sleeves', 'sleeves reach elbow', 'long-sleeve',
+ 'long sleeves', 'with long sleeves'
+]
+upper_length_attr = {
+ 'sleeveless': 0,
+ 'without sleeves': 0,
+ 'sleeves have been cut off': 0,
+ 'tank top': 0,
+ 'tank shirt': 0,
+ 'muscle shirt': 0,
+ 'short-sleeve': 1,
+ 'with short sleeves': 1,
+ 'short sleeves': 1,
+ 'medium-sleeve': 2,
+ 'with medium sleeves': 2,
+ 'medium sleeves': 2,
+ 'sleeves reach elbow': 2,
+ 'long-sleeve': 3,
+ 'long sleeves': 3,
+ 'with long sleeves': 3
+}
+lower_length_text = [
+ 'three-point', 'medium', 'short', 'covering knee', 'cropped',
+ 'three-quarter', 'long', 'slack', 'of long length'
+]
+lower_length_attr = {
+ 'three-point': 0,
+ 'medium': 1,
+ 'covering knee': 1,
+ 'short': 1,
+ 'cropped': 2,
+ 'three-quarter': 2,
+ 'long': 3,
+ 'slack': 3,
+ 'of long length': 3
+}
+socks_length_text = [
+ 'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery'
+]
+socks_length_attr = {
+ 'socks': 0,
+ 'stocking': 1,
+ 'pantyhose': 1,
+ 'leggings': 1,
+ 'sheer hosiery': 1
+}
+hat_text = ['hat', 'cap', 'chapeau']
+eyeglasses_text = ['sunglasses']
+belt_text = ['belt', 'with a dress tied around the waist']
+outer_shape_text = [
+ 'with outer clothing open', 'with outer clothing unzipped',
+ 'covering inner clothes', 'with outer clothing zipped'
+]
+outer_shape_attr = {
+ 'with outer clothing open': 0,
+ 'with outer clothing unzipped': 0,
+ 'covering inner clothes': 1,
+ 'with outer clothing zipped': 1
+}
+
+upper_types = [
+ 'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee'
+]
+outer_types = [
+ 'jacket', 'outer clothing', 'coat', 'overcoat', 'blazer', 'outerwear',
+ 'duffle', 'cardigan'
+]
+skirt_types = ['skirt']
+dress_types = ['dress']
+pant_types = ['jeans', 'pants', 'trousers']
+rompers_types = ['rompers', 'bodysuit', 'jumpsuit']
+
+attr_names_list = [
+ 'gender', 'hair length', '0 upper clothing length',
+ '1 lower clothing length', '2 socks', '3 hat', '4 eyeglasses', '5 belt',
+ '6 opening of outer clothing', '7 upper clothes', '8 outer clothing',
+ '9 skirt', '10 dress', '11 pants', '12 rompers'
+]
+
+
+def generate_shape_attributes(user_shape_texts):
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+ parsed_texts = user_shape_texts.split(',')
+
+ text_num = len(parsed_texts)
+
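+    # human_attr holds [gender, hair length]; attr holds the 13 shape attributes
+    # named in attr_names_list (indices 0-12), initialized to default values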
+ human_attr = [0, 0]
+ attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0]
+
+ changed = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ for text_id, text in enumerate(parsed_texts):
+ user_embeddings = model.encode(text)
+ if ('man' in text) and (text_id == 0):
+ human_attr[0] = 0
+ human_attr[1] = 0
+
+ if ('woman' in text or 'lady' in text) and (text_id == 0):
+ human_attr[0] = 1
+ human_attr[1] = 2
+
+ if (not changed[0]) and (text_id == 1):
+ # upper length
+ predefined_embeddings = model.encode(upper_length_text)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ arg_idx = torch.argmax(similarities).item()
+ attr[0] = upper_length_attr[upper_length_text[arg_idx]]
+ changed[0] = 1
+
+ if (not changed[1]) and ((text_num == 2 and text_id == 1) or
+ (text_num > 2 and text_id == 2)):
+ # lower length
+ predefined_embeddings = model.encode(lower_length_text)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ arg_idx = torch.argmax(similarities).item()
+ attr[1] = lower_length_attr[lower_length_text[arg_idx]]
+ changed[1] = 1
+
+ if (not changed[2]) and (text_id > 2):
+ # socks length
+ predefined_embeddings = model.encode(socks_length_text)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ arg_idx = torch.argmax(similarities).item()
+ if similarities[0][arg_idx] > 0.7:
+ attr[2] = arg_idx + 1
+ changed[2] = 1
+
+ if (not changed[3]) and (text_id > 2):
+ # hat
+ predefined_embeddings = model.encode(hat_text)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ if similarities[0][0] > 0.7:
+ attr[3] = 1
+ changed[3] = 1
+
+ if (not changed[4]) and (text_id > 2):
+ # glasses
+ predefined_embeddings = model.encode(eyeglasses_text)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ arg_idx = torch.argmax(similarities).item()
+ if similarities[0][arg_idx] > 0.7:
+ attr[4] = arg_idx + 1
+ changed[4] = 1
+
+ if (not changed[5]) and (text_id > 2):
+ # belt
+ predefined_embeddings = model.encode(belt_text)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ arg_idx = torch.argmax(similarities).item()
+ if similarities[0][arg_idx] > 0.7:
+ attr[5] = arg_idx + 1
+ changed[5] = 1
+
+ if (not changed[6]) and (text_id == 3):
+ # outer coverage
+ predefined_embeddings = model.encode(outer_shape_text)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ arg_idx = torch.argmax(similarities).item()
+ if similarities[0][arg_idx] > 0.7:
+ attr[6] = arg_idx
+ changed[6] = 1
+
+ if (not changed[10]) and (text_num == 2 and text_id == 1):
+ # dress_types
+ predefined_embeddings = model.encode(dress_types)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ similarity_skirt = util.dot_score(user_embeddings,
+ model.encode(skirt_types))
+ if similarities[0][0] > 0.5 and similarities[0][
+ 0] > similarity_skirt[0][0]:
+ attr[10] = 1
+ attr[7] = 0
+ attr[8] = 0
+ attr[9] = 0
+ attr[11] = 0
+ attr[12] = 0
+
+ changed[0] = 1
+ changed[10] = 1
+ changed[7] = 1
+ changed[8] = 1
+ changed[9] = 1
+ changed[11] = 1
+ changed[12] = 1
+
+ if (not changed[12]) and (text_num == 2 and text_id == 1):
+ # rompers_types
+ predefined_embeddings = model.encode(rompers_types)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ max_similarity = torch.max(similarities).item()
+ if max_similarity > 0.6:
+ attr[12] = 1
+ attr[7] = 0
+ attr[8] = 0
+ attr[9] = 0
+ attr[10] = 0
+ attr[11] = 0
+
+ changed[12] = 1
+ changed[7] = 1
+ changed[8] = 1
+ changed[9] = 1
+ changed[10] = 1
+ changed[11] = 1
+
+ if (not changed[7]) and (text_num > 2 and text_id == 1):
+ # upper_types
+ predefined_embeddings = model.encode(upper_types)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ max_similarity = torch.max(similarities).item()
+ if max_similarity > 0.6:
+ attr[7] = 1
+ changed[7] = 1
+
+ if (not changed[8]) and (text_id == 3):
+ # outer_types
+ predefined_embeddings = model.encode(outer_types)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ arg_idx = torch.argmax(similarities).item()
+ if similarities[0][arg_idx] > 0.7:
+ attr[6] = outer_shape_attr[outer_shape_text[arg_idx]]
+ attr[8] = 1
+ changed[8] = 1
+
+ if (not changed[9]) and (text_num > 2 and text_id == 2):
+ # skirt_types
+ predefined_embeddings = model.encode(skirt_types)
+ similarity_skirt = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ similarity_dress = util.dot_score(user_embeddings,
+ model.encode(dress_types))
+ if similarity_skirt[0][0] > 0.7 and similarity_skirt[0][
+ 0] > similarity_dress[0][0]:
+ attr[9] = 1
+ attr[10] = 0
+ changed[9] = 1
+ changed[10] = 1
+
+ if (not changed[11]) and (text_num > 2 and text_id == 2):
+ # pant_types
+ predefined_embeddings = model.encode(pant_types)
+ similarities = util.dot_score(user_embeddings,
+ predefined_embeddings)
+ max_similarity = torch.max(similarities).item()
+ if max_similarity > 0.6:
+ attr[11] = 1
+ attr[9] = 0
+ attr[10] = 0
+ attr[12] = 0
+ changed[11] = 1
+ changed[9] = 1
+ changed[10] = 1
+ changed[12] = 1
+
+ return human_attr + attr
+
+
+def generate_texture_attributes(user_text):
+ parsed_texts = user_text.split(',')
+
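+    # map texture keywords to attribute ids (17 is the fallback class) and pad
+    # the result to three values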
+ attr = []
+ for text in parsed_texts:
+ if ('pure color' in text) or ('solid color' in text):
+ attr.append(4)
+ elif ('spline' in text) or ('stripe' in text):
+ attr.append(3)
+ elif ('plaid' in text) or ('lattice' in text):
+ attr.append(5)
+ elif 'floral' in text:
+ attr.append(1)
+ elif 'denim' in text:
+ attr.append(0)
+ else:
+ attr.append(17)
+
+ if len(attr) == 1:
+ attr.append(attr[0])
+ attr.append(17)
+
+ if len(attr) == 2:
+ attr.append(17)
+
+ return attr
+
+
+if __name__ == "__main__":
+ user_request = input('Enter your request: ')
+ while user_request != '\\q':
+ attr = generate_shape_attributes(user_request)
+ print(attr)
+ for attr_name, attr_value in zip(attr_names_list, attr):
+ print(attr_name, attr_value)
+ user_request = input('Enter your request: ')
diff --git a/Text2Human/utils/logger.py b/Text2Human/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fee1a2b221c4d219206fd8f3201db3b52566adb
--- /dev/null
+++ b/Text2Human/utils/logger.py
@@ -0,0 +1,112 @@
+import datetime
+import logging
+import time
+
+
+class MessageLogger():
+ """Message logger for printing.
+
+ Args:
+ opt (dict): Config. It contains the following keys:
+            name (str): Experiment name.
+            print_freq (int): Logging interval in iterations.
+            max_iters (int): Total number of training iterations.
+            use_tb_logger (bool): Whether to use the tensorboard logger.
+        start_iter (int): Start iteration. Default: 1.
+        tb_logger (:obj:`SummaryWriter`): Tensorboard logger. Default: None.
+ """
+
+ def __init__(self, opt, start_iter=1, tb_logger=None):
+ self.exp_name = opt['name']
+ self.interval = opt['print_freq']
+ self.start_iter = start_iter
+ self.max_iters = opt['max_iters']
+ self.use_tb_logger = opt['use_tb_logger']
+ self.tb_logger = tb_logger
+ self.start_time = time.time()
+ self.logger = get_root_logger()
+
+ def __call__(self, log_vars):
+ """Format logging message.
+
+ Args:
+ log_vars (dict): It contains the following keys:
+ epoch (int): Epoch number.
+ iter (int): Current iter.
+ lrs (list): List for learning rates.
+
+ time (float): Iter time.
+ data_time (float): Data time for each iter.
+ """
+ # epoch, iter, learning rates
+ epoch = log_vars.pop('epoch')
+ current_iter = log_vars.pop('iter')
+ lrs = log_vars.pop('lrs')
+
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
+ f'iter:{current_iter:8,d}, lr:(')
+ for v in lrs:
+ message += f'{v:.3e},'
+ message += ')] '
+
+ # time and estimated time
+ if 'time' in log_vars.keys():
+ iter_time = log_vars.pop('time')
+ data_time = log_vars.pop('data_time')
+
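+            # average seconds per iteration since logging started, extrapolated
+            # over the remaining iterations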
+ total_time = time.time() - self.start_time
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ message += f'[eta: {eta_str}, '
+ message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '
+
+ # other items, especially losses
+ for k, v in log_vars.items():
+ message += f'{k}: {v:.4e} '
+ # tensorboard logger
+ if self.use_tb_logger and 'debug' not in self.exp_name:
+ self.tb_logger.add_scalar(k, v, current_iter)
+
+ self.logger.info(message)
+
+
+def init_tb_logger(log_dir):
+ from torch.utils.tensorboard import SummaryWriter
+ tb_logger = SummaryWriter(log_dir=log_dir)
+ return tb_logger
+
+
+def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+
+ Args:
+ logger_name (str): root logger name. Default: base.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+        log_level (int): The root logger level. Default: logging.INFO.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger.hasHandlers():
+ return logger
+
+ format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'
+ logging.basicConfig(format=format_str, level=log_level)
+
+ if log_file is not None:
+ file_handler = logging.FileHandler(log_file, 'w')
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+
+ return logger
diff --git a/Text2Human/utils/options.py b/Text2Human/utils/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..1045dd07381bd680b623d0187be5353e2e3dee80
--- /dev/null
+++ b/Text2Human/utils/options.py
@@ -0,0 +1,129 @@
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import yaml
+
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+
+def parse(opt_path, is_train=True):
+ """Parse option file.
+
+ Args:
+ opt_path (str): Option file path.
+        is_train (bool): Whether in training mode or not. Default: True.
+
+ Returns:
+ (dict): Options.
+ """
+ with open(opt_path, mode='r') as f:
+ Loader, _ = ordered_yaml()
+ opt = yaml.load(f, Loader=Loader)
+
+ gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+ if opt.get('set_CUDA_VISIBLE_DEVICES', None):
+ os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+ print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True)
+ else:
+ print('gpu_list: ', gpu_list, flush=True)
+
+ opt['is_train'] = is_train
+
+ # paths
+ opt['path'] = {}
+ opt['path']['root'] = osp.abspath(
+ osp.join(__file__, osp.pardir, osp.pardir))
+ if is_train:
+ experiments_root = osp.join(opt['path']['root'], 'experiments',
+ opt['name'])
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['log'] = experiments_root
+ opt['path']['visualization'] = osp.join(experiments_root,
+ 'visualization')
+
+ # change some options for debug mode
+ if 'debug' in opt['name']:
+ opt['debug'] = True
+ opt['val_freq'] = 1
+ opt['print_freq'] = 1
+ opt['save_checkpoint_freq'] = 1
+ else: # test
+ results_root = osp.join(opt['path']['root'], 'results', opt['name'])
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+ return opt
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+    Returns:
+ (str): Option string for printing.
+ """
+ msg = ''
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_level * 2) + k + ':[\n'
+ msg += dict2str(v, indent_level + 1)
+ msg += ' ' * (indent_level * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+ return msg
+
+
+class NoneDict(dict):
+ """None dict. It will return none if key is not in the dict."""
+
+ def __missing__(self, key):
+ return None
+
+
+def dict_to_nonedict(opt):
+ """Convert to NoneDict, which returns None for missing keys.
+
+ Args:
+ opt (dict): Option dict.
+
+ Returns:
+ (dict): NoneDict for options.
+ """
+ if isinstance(opt, dict):
+ new_opt = dict()
+ for key, sub_opt in opt.items():
+ new_opt[key] = dict_to_nonedict(sub_opt)
+ return NoneDict(**new_opt)
+ elif isinstance(opt, list):
+ return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+ else:
+ return opt
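+
+
+# Illustrative usage sketch (not part of the original module): wrap a plain
+# option dict so that missing keys read as None, then pretty-print it. A real
+# run would build the dict via parse('<config>.yml', is_train=False) instead.
+if __name__ == '__main__':
+    example_opt = dict_to_nonedict({'name': 'debug_example', 'gpu_ids': [0]})
+    print(dict2str(example_opt))
+    # NoneDict returns None for missing keys instead of raising KeyError.
+    print(example_opt['missing_key'])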
diff --git a/Text2Human/utils/util.py b/Text2Human/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f51663ff61a7ebb0b6c3b34633dcf417039fb762
--- /dev/null
+++ b/Text2Human/utils/util.py
@@ -0,0 +1,123 @@
+import logging
+import os
+import random
+import sys
+import time
+from shutil import get_terminal_size
+
+import numpy as np
+import torch
+
+logger = logging.getLogger('base')
+
+
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt['path'].copy()
+ if opt['is_train']:
+ overwrite = True if 'debug' in opt['name'] else False
+ os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite)
+ os.makedirs(path_opt.pop('models'), exist_ok=overwrite)
+ else:
+ os.makedirs(path_opt.pop('results_root'))
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+class ProgressBar(object):
+ """A progress bar which can print the progress.
+
+ Modified from:
+ https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
+ """
+
+ def __init__(self, task_num=0, bar_width=50, start=True):
+ self.task_num = task_num
+ max_bar_width = self._get_max_bar_width()
+ self.bar_width = (
+ bar_width if bar_width <= max_bar_width else max_bar_width)
+ self.completed = 0
+ if start:
+ self.start()
+
+ def _get_max_bar_width(self):
+ terminal_width, _ = get_terminal_size()
+ max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
+ if max_bar_width < 10:
+            print(f'terminal width is too small ({terminal_width}), '
+                  'please consider widening the terminal for a better '
+                  'progress bar visualization')
+ max_bar_width = 10
+ return max_bar_width
+
+ def start(self):
+ if self.task_num > 0:
+ sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
+ f'elapsed: 0s, ETA:\nStart...\n')
+ else:
+ sys.stdout.write('completed: 0, elapsed: 0s')
+ sys.stdout.flush()
+ self.start_time = time.time()
+
+ def update(self, msg='In progress...'):
+ self.completed += 1
+ elapsed = time.time() - self.start_time
+ fps = self.completed / elapsed
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ mark_width = int(self.bar_width * percentage)
+ bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
+ sys.stdout.write('\033[2F') # cursor up 2 lines
+ sys.stdout.write(
+ '\033[J'
+ ) # clean the output (remove extra chars since last display)
+ sys.stdout.write(
+ f'[{bar_chars}] {self.completed}/{self.task_num}, '
+ f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
+ f'ETA: {eta:5}s\n{msg}\n')
+ else:
+ sys.stdout.write(
+ f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, '
+ f'{fps:.1f} tasks/s')
+ sys.stdout.flush()
+
+
+class AverageMeter(object):
+ """
+ Computes and stores the average and current value
+ Imported from
+ https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
+ """
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0 # running average = running sum / running count
+ self.sum = 0 # running sum
+ self.count = 0 # running count
+
+ def update(self, val, n=1):
+ # n = batch_size
+
+ # val = batch accuracy for an attribute
+ # self.val = val
+
+ # sum = 100 * accumulative correct predictions for this attribute
+ self.sum += val * n
+
+ # count = total samples so far
+ self.count += n
+
+ # avg = 100 * avg accuracy for this attribute
+ # for all the batches so far
+ self.avg = self.sum / self.count
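+
+
+# Illustrative usage sketch (not part of the original module): track a running
+# average with AverageMeter and report progress with ProgressBar. The batch
+# values below are made up.
+if __name__ == '__main__':
+    meter = AverageMeter()
+    for batch_acc, batch_size in [(0.75, 32), (0.80, 32), (0.90, 16)]:
+        meter.update(batch_acc, n=batch_size)
+    print(f'running average: {meter.avg:.4f} over {meter.count} samples')
+
+    bar = ProgressBar(task_num=3)
+    for i in range(3):
+        time.sleep(0.1)
+        bar.update(msg=f'finished task {i}')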
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e32059b050bb09ba0808f6a9434c78e6f31964
--- /dev/null
+++ b/app.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import argparse
+import os
+import pathlib
+import subprocess
+
+import gradio as gr
+
+if os.getenv("SYSTEM") == "spaces":
+ import mim
+
+ mim.uninstall("mmcv-full", confirm_yes=True)
+ mim.install("mmcv-full==1.5.2", is_yes=True)
+
+ with open("patch") as f:
+ subprocess.run("patch -p1".split(), cwd="Text2Human", stdin=f)
+
+from model import Model
+
+DESCRIPTION = """# Text2Human
+
+- The algorithm is from https://github.com/yumingj/Text2Human; the demo was made by @hysts. Thanks for their awesome work.
+
+- By varying seeds, you can sample different human images under the same pose, shape description, and texture description. The larger the number of sample steps, the better the quality of the generated images. (The default value of sample steps is 256 in the original repo.)
+
+- The label image generation step can be skipped. However, in that case, the input label image must be 512x256 in size and must contain only the specified colors.
+"""
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--device", type=str, default="cpu")
+ parser.add_argument("--theme", type=str)
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--disable-queue", dest="enable_queue", action="store_false")
+ return parser.parse_args()
+
+
+# def set_example_image(example: list) -> dict:
+# return gr.Image.update(value=example[0])
+
+
+def set_example_image(example: list) -> dict:
+ return gr.update(value=example[0]["path"])
+
+
+# def set_example_text(example: list) -> dict:
+# return gr.Textbox.change(value=example[0])
+
+
+def set_example_text(example: list) -> dict:
+ # Update the Textbox with the example text
+ return gr.update(value=example[0])
+
+
+def main():
+ args = parse_args()
+ print(args.device)
+ model = Model(args.device)
+
+ with gr.Blocks(theme=args.theme, css="style.css") as demo:
+ gr.Markdown(DESCRIPTION)
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ input_image = gr.Image(
+ label="Input Pose Image", type="pil", elem_id="input-image"
+ )
+ pose_data = gr.State()
+ with gr.Row():
+ paths = sorted(pathlib.Path("pose_images").glob("*.png"))
+ example_images = gr.Dataset(
+ components=[input_image],
+ samples=[[path.as_posix()] for path in paths],
+ )
+
+ with gr.Row():
+ shape_text = gr.Textbox(
+ label="Shape Description",
+ placeholder=""", , , , , ...
+Note: The outer clothing type and accessories can be omitted.""",
+ )
+ with gr.Row():
+ shape_example_texts = gr.Dataset(
+ components=[shape_text],
+ samples=[
+ ["man, sleeveless T-shirt, long pants"],
+ ["woman, short-sleeve T-shirt, short jeans"],
+ ],
+ )
+ with gr.Row():
+ generate_label_button = gr.Button("Generate Label Image")
+
+ with gr.Column():
+ with gr.Row():
+ label_image = gr.Image(
+ label="Label Image", type="numpy", elem_id="label-image"
+ )
+
+ with gr.Row():
+ texture_text = gr.Textbox(
+ label="Texture Description",
+ placeholder=""", ,
+Note: Currently, only 5 types of textures are supported, i.e., pure color, stripe/spline, plaid/lattice, floral, denim.""",
+ )
+ with gr.Row():
+ texture_example_texts = gr.Dataset(
+ components=[texture_text],
+ samples=[["pure color, denim"], ["floral, stripe"]],
+ )
+ with gr.Row():
+ sample_steps = gr.Slider(
+ 10, 300, value=10, step=10, label="Sample Steps"
+ )
+ with gr.Row():
+ seed = gr.Slider(0, 1000000, value=0, step=1, label="Seed")
+ with gr.Row():
+ generate_human_button = gr.Button("Generate Human")
+
+ with gr.Column():
+ with gr.Row():
+ result = gr.Image(
+ label="Result", type="numpy", elem_id="result-image"
+ )
+
+
+ input_image.change(
+ fn=model.process_pose_image, inputs=input_image, outputs=pose_data
+ )
+ generate_label_button.click(
+ fn=model.generate_label_image,
+ inputs=[
+ pose_data,
+ shape_text,
+ ],
+ outputs=label_image,
+ )
+ # generate_human_button.click(
+ # fn=model.generate_human,
+ # inputs=[
+ # label_image,
+ # texture_text,
+ # sample_steps,
+ # seed,
+ # ],
+ # outputs=result,
+ # )
+ generate_human_button.click(
+ fn=model.generate_human,
+ inputs=[
+ pose_data,
+ shape_text,
+ texture_text,
+ sample_steps,
+ seed,
+ ],
+ outputs=result,
+ )
+ example_images.click(
+ fn=set_example_image,
+ inputs=example_images,
+ outputs=example_images._components,
+ )
+ shape_example_texts.click(
+ fn=set_example_text,
+ inputs=shape_example_texts,
+ outputs=shape_example_texts._components,
+ )
+ texture_example_texts.click(
+ fn=set_example_text,
+ inputs=texture_example_texts,
+ outputs=texture_example_texts._components,
+ )
+
+ demo.launch(
+ # enable_queue=args.enable_queue,
+ server_port=args.port,
+ share=args.share,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bdb29eefb6a3ecd5adbb072cb93b4b4f0a3a277
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+
+# @title Load modules
+import os
+import random
+
+import numpy as np
+import torch
+from IPython.display import display
+from PIL import Image
+
+from model import Model
+
+# @title Load model
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = Model(device)
+print(
+ "Model loaded. Parameters:",
+ sum(
+ x.numel()
+ for y in [
+ model.model.shape_attr_embedder.parameters(),
+ model.model.shape_parsing_encoder.parameters(),
+ model.model.shape_parsing_decoder.parameters(),
+ ]
+ for x in y
+ ),
+)
+
+# @title Patch PIL
+from collections import namedtuple
+
+# Monkey-patch `Image.Resampling` so that `Image.Resampling.LANCZOS` is
+# available even on Pillow versions that lack the `Resampling` enum (< 9.1).
+Image.Resampling = namedtuple("Patch", ["LANCZOS"])(Image.LANCZOS)
+
+"""# Usage"""
+
+# @title Generation parameters
+# @markdown Path to the pose image (replace "./001.png" with your own file if desired)
+pose_image = Image.open("./001.png")
+# @markdown Shape text describes the overall clothing shape; texture text describes the clothing textures/colors
+shape_text = "A lady with a T-shirt and a skirt" # @param {type: "string"}
+texture_text = "Lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt." # @param {type: "string"}
+steps = 50 # @param {type: "slider", min: 10, max:300, step: 10}
+
+seed = -1 # @param {type: "integer"}
+if seed == -1:
+ seed = random.getrandbits(16)
+print("Seed:", seed)
+
+# %%time
+# @title Generate label image
+print("Pose image:")
+display(pose_image)
+print(type(pose_image))
+print(pose_image.size)
+print("Shape description:", shape_text)
+label_image = model.generate_label_image(
+ pose_data=model.process_pose_image(pose_image), shape_text=shape_text
+)
+print("Label image:")
+print(np.sum(label_image == -1))
+display(Image.fromarray(label_image).resize((128, 256)))
+
+# Commented out IPython magic to ensure Python compatibility.
+# %%time
+# #@title Generate human image
+# print("Label mask:")
+# display(Image.fromarray(label_image).resize((128, 256)))
+# print("Texture text:", texture_text)
+# print("Generation steps:", steps)
+# result = model.generate_human(label_image=label_image,
+# texture_text=texture_text,
+# sample_steps=steps,
+# seed=0)
+# print("Resulting image:")
+# display(Image.fromarray(result))
diff --git a/final_image.jpg b/final_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..981cd13336e2e68c6f4e2f357f0f595daf13a1f1
Binary files /dev/null and b/final_image.jpg differ
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80f490158072b165e79c5dfa8355d028be584c7
--- /dev/null
+++ b/model.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+import os
+import pathlib
+import sys
+import zipfile
+
+import huggingface_hub
+import numpy as np
+import PIL.Image
+import torch
+
+sys.path.insert(0, 'Text2Human')
+
+from models.sample_model import SampleFromPoseModel
+from utils.language_utils import (generate_shape_attributes,
+ generate_texture_attributes)
+from utils.options import dict_to_nonedict, parse
+from utils.util import set_random_seed
+
+COLOR_LIST = [
+ (0, 0, 0),
+ (255, 250, 250),
+ (220, 220, 220),
+ (250, 235, 215),
+ (255, 250, 205),
+ (211, 211, 211),
+ (70, 130, 180),
+ (127, 255, 212),
+ (0, 100, 0),
+ (50, 205, 50),
+ (255, 255, 0),
+ (245, 222, 179),
+ (255, 140, 0),
+ (255, 0, 0),
+ (16, 78, 139),
+ (144, 238, 144),
+ (50, 205, 174),
+ (50, 155, 250),
+ (160, 140, 88),
+ (213, 140, 88),
+ (90, 140, 90),
+ (185, 210, 205),
+ (130, 165, 180),
+ (225, 141, 151),
+]
+
+
+class Model:
+ def __init__(self, device: str):
+ self.config = self._load_config()
+ self.config['device'] = device
+ self._download_models()
+ self.model = SampleFromPoseModel(self.config)
+ self.model.batch_size = 1
+
+ def _load_config(self) -> dict:
+ path = 'Text2Human/configs/sample_from_pose.yml'
+ config = parse(path, is_train=False)
+ config = dict_to_nonedict(config)
+ return config
+
+ def _download_models(self) -> None:
+ model_dir = pathlib.Path('pretrained_models')
+ if model_dir.exists():
+ return
+ token = os.getenv('HF_TOKEN')
+ path = huggingface_hub.hf_hub_download('yumingj/Text2Human_SSHQ',
+ 'pretrained_models.zip',
+ use_auth_token=token)
+ model_dir.mkdir()
+ with zipfile.ZipFile(path) as f:
+ f.extractall(model_dir)
+
+ @staticmethod
+ def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
+ image = np.array(
+ image.resize(
+ size=(256, 512),
+ resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:].transpose(
+ 2, 0, 1).astype(np.float32)
+ image = image / 12. - 1
+ data = torch.from_numpy(image).unsqueeze(1)
+ return data
+
+ @staticmethod
+ def process_mask(mask: np.ndarray) -> np.ndarray:
+ if mask.shape != (512, 256, 3):
+ return None
+ seg_map = np.full(mask.shape[:-1], -1)
+ for index, color in enumerate(COLOR_LIST):
+ seg_map[np.sum(mask == color, axis=2) == 3] = index
+ if not (seg_map != -1).all():
+ return None
+ return seg_map
+    # Alternative (commented-out) version of process_mask that returns a
+    # re-colored 3-channel image instead of a segmentation map:
+    # def process_mask(self, mask: np.ndarray) -> np.ndarray:
+    #     if mask.shape != (512, 256, 3):
+    #         return None
+    #     seg_map = np.full(mask.shape[:-1], -1)
+    #     for index, color in enumerate(COLOR_LIST):
+    #         seg_map[np.sum(mask == color, axis=2) == 3] = index
+
+    #     # Create a new 3-channel image for the output result
+    #     result = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
+
+    #     # Assign each matched pixel its corresponding color
+    #     for index, color in enumerate(COLOR_LIST):
+    #         result[seg_map == index] = color
+
+    #     # Set unmatched pixels to white
+    #     result[seg_map == -1] = (255, 250, 250)
+
+    #     return result
+
+
+ @staticmethod
+ def postprocess(result: torch.Tensor) -> np.ndarray:
+ result = result.permute(0, 2, 3, 1)
+ result = result.detach().cpu().numpy()
+ result = result * 255
+ result = np.asarray(result[0, :, :, :], dtype=np.uint8)
+ return result
+
+ def process_pose_image(self, pose_image: PIL.Image.Image) -> torch.Tensor:
+ if pose_image is None:
+ return
+ data = self.preprocess_pose_image(pose_image)
+ self.model.feed_pose_data(data)
+ return data
+
+ def generate_label_image(self, pose_data: torch.Tensor,
+ shape_text: str) -> np.ndarray:
+ if pose_data is None:
+ return
+ self.model.feed_pose_data(pose_data)
+ shape_attributes = generate_shape_attributes(shape_text)
+ shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
+ self.model.feed_shape_attributes(shape_attributes)
+ self.model.generate_parsing_map()
+ self.model.generate_quantized_segm()
+ colored_segm = self.model.palette_result(self.model.segm[0].cpu())
+ return colored_segm
+
+ # def generate_human(self, label_image: np.ndarray, texture_text: str,
+ # sample_steps: int, seed: int) -> np.ndarray:
+ # if label_image is None:
+ # return
+ # mask = label_image.copy()
+ # seg_map = self.process_mask(mask)
+ # if seg_map is None:
+ # return
+ # self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
+ # 0).to(self.model.device)
+ # self.model.generate_quantized_segm()
+
+ # set_random_seed(seed)
+
+ # texture_attributes = generate_texture_attributes(texture_text)
+ # texture_attributes = torch.LongTensor(texture_attributes)
+ # self.model.feed_texture_attributes(texture_attributes)
+ # self.model.generate_texture_map()
+
+ # self.model.sample_steps = sample_steps
+ # out = self.model.sample_and_refine()
+ # res = self.postprocess(out)
+ # return res
+    def generate_human(self, pose_data: torch.Tensor, shape_text: str,
+                       texture_text: str, sample_steps: int,
+                       seed: int) -> np.ndarray:
+ if pose_data is None:
+ return
+ self.model.feed_pose_data(pose_data)
+ shape_attributes = generate_shape_attributes(shape_text)
+ shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
+ self.model.feed_shape_attributes(shape_attributes)
+ self.model.generate_parsing_map()
+ self.model.generate_quantized_segm()
+ set_random_seed(seed)
+
+ texture_attributes = generate_texture_attributes(texture_text)
+ texture_attributes = torch.LongTensor(texture_attributes)
+ self.model.feed_texture_attributes(texture_attributes)
+ self.model.generate_texture_map()
+
+ self.model.sample_steps = sample_steps
+ out = self.model.sample_and_refine()
+ res = self.postprocess(out)
+ return res
+
+if __name__ == "__main__":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = Model(device)
+ pose_image = PIL.Image.open("./001.png")
+    input_image = model.process_pose_image(pose_image)
+ shape_text = "A lady with a T-shirt and a skirt"
+ # res = model.generate_label_image(pose_data=input_image, shape_text=shape_text)
+ # # PIL.Image.SAVE(res, "result.png")
+ # im = PIL.Image.fromarray(res)
+ # im.save("label_image.jpg")
+ # print(res.shape)
+    all_res = model.generate_human(
+        pose_data=input_image,
+        shape_text=shape_text,
+        texture_text="A lady with a T-shirt and a skirt",
+        sample_steps=10,
+        seed=0,
+    )
+ final_im = PIL.Image.fromarray(all_res)
+ final_im.save("final_image.jpg")
+ print(all_res.shape)
\ No newline at end of file
diff --git a/patch b/patch
new file mode 100644
index 0000000000000000000000000000000000000000..875e83b569e51cd77c3c16840d41e7d9f1c3acfa
--- /dev/null
+++ b/patch
@@ -0,0 +1,169 @@
+diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
+index 3116307..5de661d 100644
+--- a/models/hierarchy_inference_model.py
++++ b/models/hierarchy_inference_model.py
+@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
+
+ def __init__(self, opt):
+ self.opt = opt
+- self.device = torch.device('cuda')
++ self.device = torch.device(opt['device'])
+ self.is_train = opt['is_train']
+
+ self.top_encoder = Encoder(
+diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
+index 4b0d657..0bf4712 100644
+--- a/models/hierarchy_vqgan_model.py
++++ b/models/hierarchy_vqgan_model.py
+@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
+
+ def __init__(self, opt):
+ self.opt = opt
+- self.device = torch.device('cuda')
++ self.device = torch.device(opt['device'])
+ self.top_encoder = Encoder(
+ ch=opt['top_ch'],
+ num_res_blocks=opt['top_num_res_blocks'],
+diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
+index 9440345..15a1ecb 100644
+--- a/models/parsing_gen_model.py
++++ b/models/parsing_gen_model.py
+@@ -22,7 +22,7 @@ class ParsingGenModel():
+
+ def __init__(self, opt):
+ self.opt = opt
+- self.device = torch.device('cuda')
++ self.device = torch.device(opt['device'])
+ self.is_train = opt['is_train']
+
+ self.attr_embedder = ShapeAttrEmbedding(
+diff --git a/models/sample_model.py b/models/sample_model.py
+index 4c60e3f..5265cd0 100644
+--- a/models/sample_model.py
++++ b/models/sample_model.py
+@@ -23,7 +23,7 @@ class BaseSampleModel():
+
+ def __init__(self, opt):
+ self.opt = opt
+- self.device = torch.device('cuda')
++ self.device = torch.device(opt['device'])
+
+ # hierarchical VQVAE
+ self.decoder = Decoder(
+@@ -123,7 +123,7 @@ class BaseSampleModel():
+
+ def load_top_pretrain_models(self):
+ # load pretrained vqgan
+- top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
++ top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
+
+ self.decoder.load_state_dict(
+ top_vae_checkpoint['decoder'], strict=True)
+@@ -137,7 +137,7 @@ class BaseSampleModel():
+ self.top_post_quant_conv.eval()
+
+ def load_bot_pretrain_network(self):
+- checkpoint = torch.load(self.opt['bot_vae_path'])
++ checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
+ self.bot_decoder_res.load_state_dict(
+ checkpoint['bot_decoder_res'], strict=True)
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
+@@ -153,7 +153,7 @@ class BaseSampleModel():
+
+ def load_pretrained_segm_token(self):
+ # load pretrained vqgan for segmentation mask
+- segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
++ segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
+ self.segm_encoder.load_state_dict(
+ segm_token_checkpoint['encoder'], strict=True)
+ self.segm_quantizer.load_state_dict(
+@@ -166,7 +166,7 @@ class BaseSampleModel():
+ self.segm_quant_conv.eval()
+
+ def load_index_pred_network(self):
+- checkpoint = torch.load(self.opt['pretrained_index_network'])
++ checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
+ self.index_pred_guidance_encoder.load_state_dict(
+ checkpoint['guidance_encoder'], strict=True)
+ self.index_pred_decoder.load_state_dict(
+@@ -176,7 +176,7 @@ class BaseSampleModel():
+ self.index_pred_decoder.eval()
+
+ def load_sampler_pretrained_network(self):
+- checkpoint = torch.load(self.opt['pretrained_sampler'])
++ checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
+ self.sampler_fn.load_state_dict(checkpoint, strict=True)
+ self.sampler_fn.eval()
+
+@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
+
+ def load_shape_generation_models(self):
+- checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
++ checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
+
+ self.shape_attr_embedder.load_state_dict(
+ checkpoint['embedder'], strict=True)
+diff --git a/models/transformer_model.py b/models/transformer_model.py
+index 7db0f3e..4523d17 100644
+--- a/models/transformer_model.py
++++ b/models/transformer_model.py
+@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
+
+ def __init__(self, opt):
+ self.opt = opt
+- self.device = torch.device('cuda')
++ self.device = torch.device(opt['device'])
+ self.is_train = opt['is_train']
+
+ # VQVAE for image
+@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
+ def sample_fn(self, temp=1.0, sample_steps=None):
+ self._denoise_fn.eval()
+
+- b, device = self.image.size(0), 'cuda'
++ b = self.image.size(0)
+ x_t = torch.ones(
+- (b, np.prod(self.shape)), device=device).long() * self.mask_id
+- unmasked = torch.zeros_like(x_t, device=device).bool()
++ (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
++ unmasked = torch.zeros_like(x_t, device=self.device).bool()
+ sample_steps = list(range(1, sample_steps + 1))
+
+ texture_mask_flatten = self.texture_tokens.view(-1)
+@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
+
+ for t in reversed(sample_steps):
+ print(f'Sample timestep {t:4d}', end='\r')
+- t = torch.full((b, ), t, device=device, dtype=torch.long)
++ t = torch.full((b, ), t, device=self.device, dtype=torch.long)
+
+ # where to unmask
+ changes = torch.rand(
+- x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
++ x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
+ # don't unmask somewhere already unmasked
+ changes = torch.bitwise_xor(changes,
+ torch.bitwise_and(changes, unmasked))
+diff --git a/models/vqgan_model.py b/models/vqgan_model.py
+index 13a2e70..9c840f1 100644
+--- a/models/vqgan_model.py
++++ b/models/vqgan_model.py
+@@ -20,7 +20,7 @@ class VQModel():
+ def __init__(self, opt):
+ super().__init__()
+ self.opt = opt
+- self.device = torch.device('cuda')
++ self.device = torch.device(opt['device'])
+ self.encoder = Encoder(
+ ch=opt['ch'],
+ num_res_blocks=opt['num_res_blocks'],
+@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
+
+ def __init__(self, opt):
+ self.opt = opt
+- self.device = torch.device('cuda')
++ self.device = torch.device(opt['device'])
+ self.encoder = Encoder(
+ ch=opt['ch'],
+ num_res_blocks=opt['num_res_blocks'],
diff --git a/pose_images/000.png b/pose_images/000.png
new file mode 100644
index 0000000000000000000000000000000000000000..5ac17b76d375708117bf68d4e31a0a644a2f1f8c
Binary files /dev/null and b/pose_images/000.png differ
diff --git a/pose_images/001.png b/pose_images/001.png
new file mode 100644
index 0000000000000000000000000000000000000000..06b3319c92e999d22b68a76c76fc3a1bdb0cc4c3
Binary files /dev/null and b/pose_images/001.png differ
diff --git a/pose_images/002.png b/pose_images/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..06764d11b32add939777793c70174da7550d288a
Binary files /dev/null and b/pose_images/002.png differ
diff --git a/pose_images/003.png b/pose_images/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..06b3319c92e999d22b68a76c76fc3a1bdb0cc4c3
Binary files /dev/null and b/pose_images/003.png differ
diff --git a/pose_images/004.png b/pose_images/004.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3c10c3b68e4bc8f04140afa9797a70bd4ad2d08
Binary files /dev/null and b/pose_images/004.png differ
diff --git a/pose_images/005.png b/pose_images/005.png
new file mode 100644
index 0000000000000000000000000000000000000000..d65e96e50c0ed355d0f1d51391539ba7ed66dc06
Binary files /dev/null and b/pose_images/005.png differ
diff --git a/pretrained_models/__MACOSX/._index_pred_net.pth b/pretrained_models/__MACOSX/._index_pred_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d57e39e26775a0b0b44ec9305927f28b75afe63
--- /dev/null
+++ b/pretrained_models/__MACOSX/._index_pred_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a2dcb650ef91e9300ae27f9ae1656620ba2058d2d9c3b954c634228a0dc4af7
+size 212
diff --git a/pretrained_models/__MACOSX/._parsing_gen.pth b/pretrained_models/__MACOSX/._parsing_gen.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d57e39e26775a0b0b44ec9305927f28b75afe63
--- /dev/null
+++ b/pretrained_models/__MACOSX/._parsing_gen.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a2dcb650ef91e9300ae27f9ae1656620ba2058d2d9c3b954c634228a0dc4af7
+size 212
diff --git a/pretrained_models/__MACOSX/._parsing_token.pth b/pretrained_models/__MACOSX/._parsing_token.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d57e39e26775a0b0b44ec9305927f28b75afe63
--- /dev/null
+++ b/pretrained_models/__MACOSX/._parsing_token.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a2dcb650ef91e9300ae27f9ae1656620ba2058d2d9c3b954c634228a0dc4af7
+size 212
diff --git a/pretrained_models/__MACOSX/._sampler.pth b/pretrained_models/__MACOSX/._sampler.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d57e39e26775a0b0b44ec9305927f28b75afe63
--- /dev/null
+++ b/pretrained_models/__MACOSX/._sampler.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a2dcb650ef91e9300ae27f9ae1656620ba2058d2d9c3b954c634228a0dc4af7
+size 212
diff --git a/pretrained_models/__MACOSX/._vqvae_bottom.pth b/pretrained_models/__MACOSX/._vqvae_bottom.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d57e39e26775a0b0b44ec9305927f28b75afe63
--- /dev/null
+++ b/pretrained_models/__MACOSX/._vqvae_bottom.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a2dcb650ef91e9300ae27f9ae1656620ba2058d2d9c3b954c634228a0dc4af7
+size 212
diff --git a/pretrained_models/__MACOSX/._vqvae_top.pth b/pretrained_models/__MACOSX/._vqvae_top.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d57e39e26775a0b0b44ec9305927f28b75afe63
--- /dev/null
+++ b/pretrained_models/__MACOSX/._vqvae_top.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a2dcb650ef91e9300ae27f9ae1656620ba2058d2d9c3b954c634228a0dc4af7
+size 212
diff --git a/pretrained_models/index_pred_net.pth b/pretrained_models/index_pred_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8b161ff6cb88ffcee8e7b3be88407673441382be
--- /dev/null
+++ b/pretrained_models/index_pred_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02a28bc47e52d1a1b584f20de73d8757811e934995df05fa5706720937d3ee3b
+size 121610787
diff --git a/pretrained_models/parsing_gen.pth b/pretrained_models/parsing_gen.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9414e1a5f230be0cd9943e76c3bbb7318b238d3b
--- /dev/null
+++ b/pretrained_models/parsing_gen.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21b95728125a150b2e7c18ab8c0da0fcbdc0d50b37a77f94fe1b8b2cbd18a402
+size 125357083
diff --git a/pretrained_models/parsing_token.pth b/pretrained_models/parsing_token.pth
new file mode 100644
index 0000000000000000000000000000000000000000..e1985ac8cb5fc4dd11036997fb7c8c588d65b31d
--- /dev/null
+++ b/pretrained_models/parsing_token.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30c57ece5aa98e7f9a6fe0d80df5bc5e12221559b26d0309c345c8160cf8a7ca
+size 49738412
diff --git a/pretrained_models/sampler.pth b/pretrained_models/sampler.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c62ff46449d518a3ee8762c0f5656b3220dd4896
--- /dev/null
+++ b/pretrained_models/sampler.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87278ce360641eb5e6ea58d294c788523b3f8b301506517459ed9f8d605abde7
+size 381469328
diff --git a/pretrained_models/vqvae_bottom.pth b/pretrained_models/vqvae_bottom.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d219004afe3db8d31b1cb201645a41e87d366fe3
--- /dev/null
+++ b/pretrained_models/vqvae_bottom.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa9a0b8b9d0d5df22dbb8edd7df1b6e801575b95a899fd5937587be297d6b5f2
+size 371372343
diff --git a/pretrained_models/vqvae_top.pth b/pretrained_models/vqvae_top.pth
new file mode 100644
index 0000000000000000000000000000000000000000..189c743e02ff8f19c71e323e14ee854db5278397
--- /dev/null
+++ b/pretrained_models/vqvae_top.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e9071e037739e2a5f39ab363653866743611f58dd9fef3be1327023f47aa9cd
+size 317603805
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..46d8d007aa59c41ffc8114bda3ea70862efd1e8d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,119 @@
+addict==2.4.0
+aiofiles==23.2.1
+aliyun-python-sdk-core==2.15.1
+aliyun-python-sdk-kms==2.16.3
+altair==5.3.0
+annotated-types==0.6.0
+anyio==4.3.0
+attrs==23.2.0
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==2.1.1
+click==8.1.7
+colorama==0.4.6
+contourpy==1.2.1
+crcmod==1.7
+cryptography==42.0.7
+cycler==0.12.1
+dnspython==2.6.1
+einops==0.4.1
+email_validator==2.1.1
+exceptiongroup==1.2.1
+fastapi==0.111.0
+fastapi-cli==0.0.3
+ffmpy==0.3.2
+filelock==3.14.0
+fonttools==4.51.0
+fsspec==2024.3.1
+gradio==4.31.2
+gradio_client==0.16.3
+h11==0.14.0
+httpcore==1.0.5
+httptools==0.6.1
+httpx==0.27.0
+huggingface-hub==0.23.0
+idna==3.7
+importlib_metadata==7.1.0
+importlib_resources==6.4.0
+Jinja2==3.1.4
+jmespath==0.10.0
+joblib==1.4.2
+jsonschema==4.22.0
+jsonschema-specifications==2023.12.1
+kiwisolver==1.4.5
+lpips==0.1.4
+Markdown==3.6
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.8.4
+mdurl==0.1.2
+mmcls==0.25.0
+mmcv-full==1.5.2
+mmsegmentation==0.24.1
+model-index==0.1.11
+nltk==3.8.1
+numpy==1.22.3
+opencv-python==4.9.0.80
+opendatalab==0.0.10
+openmim==0.3.9
+openxlab==0.0.35
+ordered-set==4.1.0
+orjson==3.10.3
+oss2==2.17.0
+packaging==24.0
+pandas==2.0.3
+Pillow==9.1.1
+platformdirs==4.2.1
+prettytable==3.10.0
+pycparser==2.22
+pycryptodome==3.20.0
+pydantic==2.7.1
+pydantic_core==2.18.2
+pydub==0.25.1
+Pygments==2.18.0
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-multipart==0.0.9
+pytz==2023.4
+pywin32==306
+PyYAML==6.0.1
+referencing==0.35.1
+regex==2024.5.10
+requests==2.28.2
+rich==13.4.2
+rpds-py==0.18.1
+ruff==0.4.4
+safetensors==0.4.3
+scikit-learn==1.4.2
+scipy==1.11.4
+semantic-version==2.10.0
+sentence-transformers==2.7.0
+sentencepiece==0.2.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+starlette==0.37.2
+tabulate==0.9.0
+threadpoolctl==3.5.0
+tokenizers==0.19.1
+tomli==2.0.1
+tomlkit==0.12.0
+toolz==0.12.1
+torch==1.13.0+cpu
+torchaudio==0.13.0+cpu
+torchvision==0.14.0+cpu
+tqdm==4.65.2
+transformers==4.40.2
+typer==0.12.3
+typing_extensions==4.11.0
+tzdata==2024.1
+ujson==5.10.0
+urllib3==1.26.18
+uvicorn==0.29.0
+watchfiles==0.21.0
+wcwidth==0.2.13
+websockets==11.0.3
+windows-curses==2.3.3
+yapf==0.40.2
+zipp==3.18.1
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..22ad0be91ed35841bc456be4a0044474affc9a17
--- /dev/null
+++ b/style.css
@@ -0,0 +1,16 @@
+h1 {
+ text-align: center;
+}
+#input-image {
+ max-height: 300px;
+}
+#label-image {
+ height: 300px;
+}
+#result-image {
+ height: 300px;
+}
+img#visitor-badge {
+ display: block;
+ margin: auto;
+}