diff --git a/apps/examples/1.jpg b/apps/examples/1.jpg deleted file mode 100644 index b4b30221c1008797a291ded245f4153bf927fbb7..0000000000000000000000000000000000000000 Binary files a/apps/examples/1.jpg and /dev/null differ diff --git a/apps/examples/1_cute_girl.webp b/apps/examples/1_cute_girl.webp deleted file mode 100644 index 648206b5aa91124707cd5513bcc6ed8e866e1de6..0000000000000000000000000000000000000000 Binary files a/apps/examples/1_cute_girl.webp and /dev/null differ diff --git a/apps/examples/bird.jpg b/apps/examples/bird.jpg deleted file mode 100644 index 62622f31d4fe078c93c546941b9b64701889f37e..0000000000000000000000000000000000000000 Binary files a/apps/examples/bird.jpg and /dev/null differ diff --git a/apps/examples/blue_monster.webp b/apps/examples/blue_monster.webp deleted file mode 100644 index 7204f5fa63d86843269210c306488444e0efafe1..0000000000000000000000000000000000000000 Binary files a/apps/examples/blue_monster.webp and /dev/null differ diff --git a/apps/examples/boy2.webp b/apps/examples/boy2.webp deleted file mode 100644 index a2e1615efd6ca3f87aac3f43884666b85a804dc4..0000000000000000000000000000000000000000 Binary files a/apps/examples/boy2.webp and /dev/null differ diff --git a/apps/examples/bulldog.webp b/apps/examples/bulldog.webp deleted file mode 100644 index 318fcf22cf756b31e25364f8fe28a3852c72fc9c..0000000000000000000000000000000000000000 Binary files a/apps/examples/bulldog.webp and /dev/null differ diff --git a/apps/examples/catman.webp b/apps/examples/catman.webp deleted file mode 100644 index 5dd68d7308d4785f314e976561facc43e9d9e53a..0000000000000000000000000000000000000000 Binary files a/apps/examples/catman.webp and /dev/null differ diff --git a/apps/examples/cyberpunk_man.webp b/apps/examples/cyberpunk_man.webp deleted file mode 100644 index 4c609f49d6dcfcb93d11c85b675fe1b0641300c9..0000000000000000000000000000000000000000 Binary files a/apps/examples/cyberpunk_man.webp and /dev/null differ diff --git a/apps/examples/dinosaur_boy.webp b/apps/examples/dinosaur_boy.webp deleted file mode 100644 index afd68e6205bba8dc5dd178ec8dcd0a6fd24287bb..0000000000000000000000000000000000000000 Binary files a/apps/examples/dinosaur_boy.webp and /dev/null differ diff --git a/apps/examples/doraemon.webp b/apps/examples/doraemon.webp deleted file mode 100644 index b5923198c1a6d3da835f9d9f6299ad8f3ad81732..0000000000000000000000000000000000000000 Binary files a/apps/examples/doraemon.webp and /dev/null differ diff --git a/apps/examples/dragon.webp b/apps/examples/dragon.webp deleted file mode 100644 index 2debdc5f7c546e8830ee6f79b7470edc43aff33a..0000000000000000000000000000000000000000 Binary files a/apps/examples/dragon.webp and /dev/null differ diff --git a/apps/examples/dragontoy.jpg b/apps/examples/dragontoy.jpg deleted file mode 100644 index 39ce417c23734e91a667eeb852fca63e1f6249b3..0000000000000000000000000000000000000000 Binary files a/apps/examples/dragontoy.jpg and /dev/null differ diff --git a/apps/examples/girl1.webp b/apps/examples/girl1.webp deleted file mode 100644 index 3e42bbdffd07d3a5449399e4c2e229d3cd1da7ec..0000000000000000000000000000000000000000 Binary files a/apps/examples/girl1.webp and /dev/null differ diff --git a/apps/examples/gun.webp b/apps/examples/gun.webp deleted file mode 100644 index 589c5391e53bb89c98673d5b2b2e0d76677bcb32..0000000000000000000000000000000000000000 Binary files a/apps/examples/gun.webp and /dev/null differ diff --git a/apps/examples/kunkun.webp b/apps/examples/kunkun.webp deleted file mode 100644 index a0c45d3c7091cdad8ef69374a9cbd065353e7982..0000000000000000000000000000000000000000 Binary files a/apps/examples/kunkun.webp and /dev/null differ diff --git a/apps/examples/link.webp b/apps/examples/link.webp deleted file mode 100644 index 1b5bcfc99fd3e2e6d674b5c877f4128109355900..0000000000000000000000000000000000000000 Binary files a/apps/examples/link.webp and /dev/null differ diff --git a/apps/examples/mushroom1.webp b/apps/examples/mushroom1.webp deleted file mode 100644 index bc41abe7d5113bb147a85fdb4f5b8f2d8e1e3851..0000000000000000000000000000000000000000 Binary files a/apps/examples/mushroom1.webp and /dev/null differ diff --git a/apps/examples/mushroom2.webp b/apps/examples/mushroom2.webp deleted file mode 100644 index 08d870a503df6970534797730f7e429e57a84ab8..0000000000000000000000000000000000000000 Binary files a/apps/examples/mushroom2.webp and /dev/null differ diff --git a/apps/examples/phoenix.webp b/apps/examples/phoenix.webp deleted file mode 100644 index 7a15695cbdbda761ac08926c40cb9d19400051d8..0000000000000000000000000000000000000000 Binary files a/apps/examples/phoenix.webp and /dev/null differ diff --git a/apps/examples/robot.png b/apps/examples/robot.png deleted file mode 100644 index 5522be23a8373e96af1f6ba0b08fe31ec0467a41..0000000000000000000000000000000000000000 Binary files a/apps/examples/robot.png and /dev/null differ diff --git a/apps/examples/rose.webp b/apps/examples/rose.webp deleted file mode 100644 index d245944b32d29a5a7c56bf72b595fcc2b7de6650..0000000000000000000000000000000000000000 Binary files a/apps/examples/rose.webp and /dev/null differ diff --git a/apps/examples/shoe.webp b/apps/examples/shoe.webp deleted file mode 100644 index 41bbb431f566663f3fb8ebd5dc32533cd4a2b990..0000000000000000000000000000000000000000 Binary files a/apps/examples/shoe.webp and /dev/null differ diff --git a/apps/examples/sports_girl.webp b/apps/examples/sports_girl.webp deleted file mode 100644 index 15006e2e0608f602aa79d17ea0f24d94b9e62061..0000000000000000000000000000000000000000 Binary files a/apps/examples/sports_girl.webp and /dev/null differ diff --git a/apps/examples/stone.webp b/apps/examples/stone.webp deleted file mode 100644 index c462867d2f5d5a0f9ab7414b20af8b6adb46e2ae..0000000000000000000000000000000000000000 Binary files a/apps/examples/stone.webp and /dev/null differ diff --git a/apps/examples/sweater.webp b/apps/examples/sweater.webp deleted file mode 100644 index 31bcb40ca77d0b2a6cb37ebde421c1d72fd44ddd..0000000000000000000000000000000000000000 Binary files a/apps/examples/sweater.webp and /dev/null differ diff --git a/apps/examples/sword.webp b/apps/examples/sword.webp deleted file mode 100644 index be576fa4a76ca8ddcc934d81caaa65eac2faa04c..0000000000000000000000000000000000000000 Binary files a/apps/examples/sword.webp and /dev/null differ diff --git a/apps/examples/teapot.webp b/apps/examples/teapot.webp deleted file mode 100644 index a1e5809eeae47e05149c9f7b19d80810de58d19f..0000000000000000000000000000000000000000 Binary files a/apps/examples/teapot.webp and /dev/null differ diff --git a/apps/examples/toy_bear.webp b/apps/examples/toy_bear.webp deleted file mode 100644 index 7cfff09ab76ea5f0ea6ac6a180d349e593e05fee..0000000000000000000000000000000000000000 Binary files a/apps/examples/toy_bear.webp and /dev/null differ diff --git a/apps/examples/toy_dog.webp b/apps/examples/toy_dog.webp deleted file mode 100644 index 9e9d6aae6fa32807b734df2351a3d07e7e1ffa55..0000000000000000000000000000000000000000 Binary files a/apps/examples/toy_dog.webp and /dev/null differ diff --git a/apps/examples/toy_pig.webp b/apps/examples/toy_pig.webp deleted file mode 100644 index 600edfaea0c067f78d9db4c1e9af044b92437c91..0000000000000000000000000000000000000000 Binary files a/apps/examples/toy_pig.webp and /dev/null differ diff --git a/apps/examples/toy_rabbit.webp b/apps/examples/toy_rabbit.webp deleted file mode 100644 index 0c51b6e83ba928031967ae091c31b40efd9580ae..0000000000000000000000000000000000000000 Binary files a/apps/examples/toy_rabbit.webp and /dev/null differ diff --git a/apps/examples/wiking.webp b/apps/examples/wiking.webp deleted file mode 100644 index 71ff562749bbe8a43c154d282f1c21ca21dd5a96..0000000000000000000000000000000000000000 Binary files a/apps/examples/wiking.webp and /dev/null differ diff --git a/apps/examples/wings.webp b/apps/examples/wings.webp deleted file mode 100644 index 41b9e0e629d029bd27674b420d6a9fe0b3655156..0000000000000000000000000000000000000000 Binary files a/apps/examples/wings.webp and /dev/null differ diff --git a/apps/third_party/CRM/LICENSE b/apps/third_party/CRM/LICENSE deleted file mode 100644 index 8840910d3e0d809884ce88440d615cf493475272..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2024 TSAIL group - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/apps/third_party/CRM/README.md b/apps/third_party/CRM/README.md deleted file mode 100644 index 0fb9821a04f51330e00206575bab88fce08fb786..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/README.md +++ /dev/null @@ -1,85 +0,0 @@ -# Convolutional Reconstruction Model - -Official implementation for *CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model*. - -**CRM is a feed-forward model which can generate 3D textured mesh in 10 seconds.** - -## [Project Page](https://ml.cs.tsinghua.edu.cn/~zhengyi/CRM/) | [Arxiv](https://arxiv.org/abs/2403.05034) | [HF-Demo](https://huggingface.co./spaces/Zhengyi/CRM) | [Weights](https://huggingface.co./Zhengyi/CRM) - -https://github.com/thu-ml/CRM/assets/40787266/8b325bc0-aa74-4c26-92e8-a8f0c1079382 - -## Try CRM 🍻 -* Try CRM at [Huggingface Demo](https://huggingface.co./spaces/Zhengyi/CRM). -* Try CRM at [Replicate Demo](https://replicate.com/camenduru/crm). Thanks [@camenduru](https://github.com/camenduru)! - -## Install - -### Step 1 - Base - -Install package one by one, we use **python 3.9** - -```bash -pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install torch-scatter==2.1.1 -f https://data.pyg.org/whl/torch-1.13.1+cu117.html -pip install kaolin==0.14.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.13.1_cu117.html -pip install -r requirements.txt -``` - -besides, one by one need to install xformers manually according to the official [doc](https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers) (**conda no need**), e.g. - -```bash -pip install ninja -pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers -``` - -### Step 2 - Nvdiffrast - -Install nvdiffrast according to the official [doc](https://nvlabs.github.io/nvdiffrast/#installation), e.g. - -```bash -pip install git+https://github.com/NVlabs/nvdiffrast -``` - - - -## Inference - -We suggest gradio for a visualized inference. - -``` -gradio app.py -``` - -![image](https://github.com/thu-ml/CRM/assets/40787266/4354d22a-a641-4531-8408-c761ead8b1a2) - -For inference in command lines, simply run -```bash -CUDA_VISIBLE_DEVICES="0" python run.py --inputdir "examples/kunkun.webp" -``` -It will output the preprocessed image, generated 6-view images and CCMs and a 3D model in obj format. - -**Tips:** (1) If the result is unsatisfatory, please check whether the input image is correctly pre-processed into a grey background. Otherwise the results will be unpredictable. -(2) Different from the [Huggingface Demo](https://huggingface.co./spaces/Zhengyi/CRM), this official implementation uses UV texture instead of vertex color. It has better texture than the online demo but longer generating time owing to the UV texturing. - -## Todo List -- [x] Release inference code. -- [x] Release pretrained models. -- [ ] Optimize inference code to fit in low memery GPU. -- [ ] Upload training code. - -## Acknowledgement -- [ImageDream](https://github.com/bytedance/ImageDream) -- [nvdiffrast](https://github.com/NVlabs/nvdiffrast) -- [kiuikit](https://github.com/ashawkey/kiuikit) -- [GET3D](https://github.com/nv-tlabs/GET3D) - -## Citation - -``` -@article{wang2024crm, - title={CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model}, - author={Zhengyi Wang and Yikai Wang and Yifei Chen and Chendong Xiang and Shuo Chen and Dajiang Yu and Chongxuan Li and Hang Su and Jun Zhu}, - journal={arXiv preprint arXiv:2403.05034}, - year={2024} -} -``` diff --git a/apps/third_party/CRM/app.py b/apps/third_party/CRM/app.py deleted file mode 100644 index dae2f6baf19cf3bb2817f166c81679ae4eec0685..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/app.py +++ /dev/null @@ -1,228 +0,0 @@ -# Not ready to use yet -import argparse -import numpy as np -import gradio as gr -from omegaconf import OmegaConf -import torch -from PIL import Image -import PIL -from pipelines import TwoStagePipeline -from huggingface_hub import hf_hub_download -import os -import rembg -from typing import Any -import json -import os -import json -import argparse - -from model import CRM -from inference import generate3d - -pipeline = None -rembg_session = rembg.new_session() - - -def expand_to_square(image, bg_color=(0, 0, 0, 0)): - # expand image to 1:1 - width, height = image.size - if width == height: - return image - new_size = (max(width, height), max(width, height)) - new_image = Image.new("RGBA", new_size, bg_color) - paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) - new_image.paste(image, paste_position) - return new_image - -def check_input_image(input_image): - if input_image is None: - raise gr.Error("No image uploaded!") - - -def remove_background( - image: PIL.Image.Image, - rembg_session = None, - force: bool = False, - **rembg_kwargs, -) -> PIL.Image.Image: - do_remove = True - if image.mode == "RGBA" and image.getextrema()[3][0] < 255: - # explain why current do not rm bg - print("alhpa channl not enpty, skip remove background, using alpha channel as mask") - background = Image.new("RGBA", image.size, (0, 0, 0, 0)) - image = Image.alpha_composite(background, image) - do_remove = False - do_remove = do_remove or force - if do_remove: - image = rembg.remove(image, session=rembg_session, **rembg_kwargs) - return image - -def do_resize_content(original_image: Image, scale_rate): - # resize image content wile retain the original image size - if scale_rate != 1: - # Calculate the new size after rescaling - new_size = tuple(int(dim * scale_rate) for dim in original_image.size) - # Resize the image while maintaining the aspect ratio - resized_image = original_image.resize(new_size) - # Create a new image with the original size and black background - padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) - paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) - padded_image.paste(resized_image, paste_position) - return padded_image - else: - return original_image - -def add_background(image, bg_color=(255, 255, 255)): - # given an RGBA image, alpha channel is used as mask to add background color - background = Image.new("RGBA", image.size, bg_color) - return Image.alpha_composite(background, image) - - -def preprocess_image(image, background_choice, foreground_ratio, backgroud_color): - """ - input image is a pil image in RGBA, return RGB image - """ - print(background_choice) - if background_choice == "Alpha as mask": - background = Image.new("RGBA", image.size, (0, 0, 0, 0)) - image = Image.alpha_composite(background, image) - else: - image = remove_background(image, rembg_session, force_remove=True) - image = do_resize_content(image, foreground_ratio) - image = expand_to_square(image) - image = add_background(image, backgroud_color) - return image.convert("RGB") - - -def gen_image(input_image, seed, scale, step): - global pipeline, model, args - pipeline.set_seed(seed) - rt_dict = pipeline(input_image, scale=scale, step=step) - stage1_images = rt_dict["stage1_images"] - stage2_images = rt_dict["stage2_images"] - np_imgs = np.concatenate(stage1_images, 1) - np_xyzs = np.concatenate(stage2_images, 1) - - glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device) - return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path, obj_path - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--stage1_config", - type=str, - default="configs/nf7_v3_SNR_rd_size_stroke.yaml", - help="config for stage1", -) -parser.add_argument( - "--stage2_config", - type=str, - default="configs/stage2-v2-snr.yaml", - help="config for stage2", -) - -parser.add_argument("--device", type=str, default="cuda") -args = parser.parse_args() - -crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth") -specs = json.load(open("configs/specs_objaverse_total.json")) -model = CRM(specs).to(args.device) -model.load_state_dict(torch.load(crm_path, map_location = args.device), strict=False) - -stage1_config = OmegaConf.load(args.stage1_config).config -stage2_config = OmegaConf.load(args.stage2_config).config -stage2_sampler_config = stage2_config.sampler -stage1_sampler_config = stage1_config.sampler - -stage1_model_config = stage1_config.models -stage2_model_config = stage2_config.models - -xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth") -pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth") -stage1_model_config.resume = pixel_path -stage2_model_config.resume = xyz_path - -pipeline = TwoStagePipeline( - stage1_model_config, - stage2_model_config, - stage1_sampler_config, - stage2_sampler_config, - device=args.device, - dtype=torch.float16 -) - -with gr.Blocks() as demo: - gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model") - with gr.Row(): - with gr.Column(): - with gr.Row(): - image_input = gr.Image( - label="Image input", - image_mode="RGBA", - sources="upload", - type="pil", - ) - processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB") - with gr.Row(): - with gr.Column(): - with gr.Row(): - background_choice = gr.Radio([ - "Alpha as mask", - "Auto Remove background" - ], value="Auto Remove background", - label="backgroud choice") - # do_remove_background = gr.Checkbox(label=, value=True) - # force_remove = gr.Checkbox(label=, value=False) - back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False) - foreground_ratio = gr.Slider( - label="Foreground Ratio", - minimum=0.5, - maximum=1.0, - value=1.0, - step=0.05, - ) - - with gr.Column(): - seed = gr.Number(value=1234, label="seed", precision=0) - guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale") - step = gr.Number(value=50, minimum=30, maximum=100, label="sample steps", precision=0) - text_button = gr.Button("Generate 3D shape") - gr.Examples( - examples=[os.path.join("examples", i) for i in os.listdir("examples")], - inputs=[image_input], - ) - with gr.Column(): - image_output = gr.Image(interactive=False, label="Output RGB image") - xyz_ouput = gr.Image(interactive=False, label="Output CCM image") - - output_model = gr.Model3D( - label="Output GLB", - interactive=False, - ) - gr.Markdown("Note: The GLB model shown here has a darker lighting and enlarged UV seams. Download for correct results.") - output_obj = gr.File(interactive=False, label="Output OBJ") - - inputs = [ - processed_image, - seed, - guidance_scale, - step, - ] - outputs = [ - image_output, - xyz_ouput, - output_model, - output_obj, - ] - - - text_button.click(fn=check_input_image, inputs=[image_input]).success( - fn=preprocess_image, - inputs=[image_input, background_choice, foreground_ratio, back_groud_color], - outputs=[processed_image], - ).success( - fn=gen_image, - inputs=inputs, - outputs=outputs, - ) - demo.queue().launch() diff --git a/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml b/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml deleted file mode 100644 index 760f41f2728a94114f674e2160a75a65a8a3a656..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml +++ /dev/null @@ -1,21 +0,0 @@ -config: -# others - seed: 1234 - num_frames: 7 - mode: pixel - offset_noise: true -# model related - models: - config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml - resume: models/pixel.pth -# sampler related - sampler: - target: libs.sample.ImageDreamDiffusion - params: - mode: pixel - num_frames: 7 - camera_views: [1, 2, 3, 4, 5, 0, 0] - ref_position: 6 - random_background: false - offset_noise: true - resize_rate: 1.0 \ No newline at end of file diff --git a/apps/third_party/CRM/configs/specs_objaverse_total.json b/apps/third_party/CRM/configs/specs_objaverse_total.json deleted file mode 100644 index c99ebee563a7d44859338382b197ef55963e87d0..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/configs/specs_objaverse_total.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "Input": { - "img_num": 16, - "class": "all", - "camera_angle_num": 8, - "tet_grid_size": 80, - "validate_num": 16, - "scale": 0.95, - "radius": 3, - "resolution": [256, 256] - }, - - "Pretrain": { - "mode": null, - "sdf_threshold": 0.1, - "sdf_scale": 10, - "batch_infer": false, - "lr": 1e-4, - "radius": 0.5 - }, - - "Train": { - "mode": "rnd", - "num_epochs": 500, - "grad_acc": 1, - "warm_up": 0, - "decay": 0.000, - "learning_rate": { - "init": 1e-4, - "sdf_decay": 1, - "rgb_decay": 1 - }, - "batch_size": 4, - "eva_iter": 80, - "eva_all_epoch": 10, - "tex_sup_mode": "blender", - "exp_uv_mesh": false, - "doub": false, - "random_bg": false, - "shift": 0, - "aug_shift": 0, - "geo_type": "flex" - }, - - "ArchSpecs": { - "unet_type": "diffusers", - "use_3D_aware": false, - "fea_concat": false, - "mlp_bias": true - }, - - "DecoderSpecs": { - "c_dim": 32, - "plane_resolution": 256 - } -} - diff --git a/apps/third_party/CRM/configs/stage2-v2-snr.yaml b/apps/third_party/CRM/configs/stage2-v2-snr.yaml deleted file mode 100644 index 8e76d1a2a8ff71ba1318c9ff6ff6c59a7a9e606e..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/configs/stage2-v2-snr.yaml +++ /dev/null @@ -1,25 +0,0 @@ -config: -# others - seed: 1234 - num_frames: 6 - mode: pixel - offset_noise: true - gd_type: xyz -# model related - models: - config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml - resume: models/xyz.pth - -# eval related - sampler: - target: libs.sample.ImageDreamDiffusionStage2 - params: - mode: pixel - num_frames: 6 - camera_views: [1, 2, 3, 4, 5, 0] - ref_position: null - random_background: false - offset_noise: true - resize_rate: 1.0 - - diff --git a/apps/third_party/CRM/imagedream/.DS_Store b/apps/third_party/CRM/imagedream/.DS_Store deleted file mode 100644 index 5e042f9aa5cf18c3b51e766fee10546755d5ed4d..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/.DS_Store and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/__init__.py b/apps/third_party/CRM/imagedream/__init__.py deleted file mode 100644 index 326f18c2d65a018b1c214c71f1c44428a8a8089b..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .model_zoo import build_model diff --git a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 5514f1066fcdb863f45a135e26483e3b31c2a44d..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index f7c150c7cb625327e1ccb59742e7c4842561eb1d..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc deleted file mode 100644 index aaaf53fc96fd79461da559cd87213eac03d63e1a..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc deleted file mode 100644 index 599048b52c049602f55ce7f6b569e278ae8e8dd6..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc deleted file mode 100644 index 7eaa9fbbbfe649e49709f705e513ef0f40d7690d..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc deleted file mode 100644 index 6bd45cca067b99143299e1ba9b78a2a23846e390..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/camera_utils.py b/apps/third_party/CRM/imagedream/camera_utils.py deleted file mode 100644 index 6fb352745d737d96b2652f80c23778058e81f6e6..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/camera_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -import numpy as np -import torch - - -def create_camera_to_world_matrix(elevation, azimuth): - elevation = np.radians(elevation) - azimuth = np.radians(azimuth) - # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere - x = np.cos(elevation) * np.sin(azimuth) - y = np.sin(elevation) - z = np.cos(elevation) * np.cos(azimuth) - - # Calculate camera position, target, and up vectors - camera_pos = np.array([x, y, z]) - target = np.array([0, 0, 0]) - up = np.array([0, 1, 0]) - - # Construct view matrix - forward = target - camera_pos - forward /= np.linalg.norm(forward) - right = np.cross(forward, up) - right /= np.linalg.norm(right) - new_up = np.cross(right, forward) - new_up /= np.linalg.norm(new_up) - cam2world = np.eye(4) - cam2world[:3, :3] = np.array([right, new_up, -forward]).T - cam2world[:3, 3] = camera_pos - return cam2world - - -def convert_opengl_to_blender(camera_matrix): - if isinstance(camera_matrix, np.ndarray): - # Construct transformation matrix to convert from OpenGL space to Blender space - flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) - camera_matrix_blender = np.dot(flip_yz, camera_matrix) - else: - # Construct transformation matrix to convert from OpenGL space to Blender space - flip_yz = torch.tensor( - [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]] - ) - if camera_matrix.ndim == 3: - flip_yz = flip_yz.unsqueeze(0) - camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix) - return camera_matrix_blender - - -def normalize_camera(camera_matrix): - """normalize the camera location onto a unit-sphere""" - if isinstance(camera_matrix, np.ndarray): - camera_matrix = camera_matrix.reshape(-1, 4, 4) - translation = camera_matrix[:, :3, 3] - translation = translation / ( - np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8 - ) - camera_matrix[:, :3, 3] = translation - else: - camera_matrix = camera_matrix.reshape(-1, 4, 4) - translation = camera_matrix[:, :3, 3] - translation = translation / ( - torch.norm(translation, dim=1, keepdim=True) + 1e-8 - ) - camera_matrix[:, :3, 3] = translation - return camera_matrix.reshape(-1, 16) - - -def get_camera( - num_frames, - elevation=15, - azimuth_start=0, - azimuth_span=360, - blender_coord=True, - extra_view=False, -): - angle_gap = azimuth_span / num_frames - cameras = [] - for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): - camera_matrix = create_camera_to_world_matrix(elevation, azimuth) - if blender_coord: - camera_matrix = convert_opengl_to_blender(camera_matrix) - cameras.append(camera_matrix.flatten()) - - if extra_view: - dim = len(cameras[0]) - cameras.append(np.zeros(dim)) - return torch.tensor(np.stack(cameras, 0)).float() - - -def get_camera_for_index(data_index): - """ - 按照当前我们的数据格式, 以000为正对我们的情况: - 000是正面, ev: 0, azimuth: 0 - 001是左边, ev: 0, azimuth: -90 - 002是下面, ev: -90, azimuth: 0 - 003是背面, ev: 0, azimuth: 180 - 004是右边, ev: 0, azimuth: 90 - 005是上面, ev: 90, azimuth: 0 - """ - params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)] - return get_camera(1, *params[data_index]) \ No newline at end of file diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml deleted file mode 100644 index b4ecc2a694dbbde82d45bc3b16d1ebbb27ac552a..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml +++ /dev/null @@ -1,61 +0,0 @@ -model: - target: imagedream.ldm.interface.LatentDiffusionInterface - params: - linear_start: 0.00085 - linear_end: 0.0120 - timesteps: 1000 - scale_factor: 0.18215 - parameterization: "eps" - - unet_config: - target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - use_checkpoint: False - legacy: False - camera_dim: 16 - with_ip: True - ip_dim: 16 # ip token length - ip_mode: "local_resample" - - vae_config: - target: imagedream.ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - clip_config: - target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml deleted file mode 100644 index ee80712395fa6323bceb437a40b4829bb497adf7..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml +++ /dev/null @@ -1,61 +0,0 @@ -model: - target: imagedream.ldm.interface.LatentDiffusionInterface - params: - linear_start: 0.00085 - linear_end: 0.0120 - timesteps: 1000 - scale_factor: 0.18215 - parameterization: "eps" - - unet_config: - target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel - params: - image_size: 32 # unused - in_channels: 8 - out_channels: 8 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - use_checkpoint: False - legacy: False - camera_dim: 16 - with_ip: True - ip_dim: 16 # ip token length - ip_mode: "local_resample" - - vae_config: - target: imagedream.ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - clip_config: - target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml deleted file mode 100644 index ce64b053fc76ded2ed85af7c5d398e2981e51c3d..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml +++ /dev/null @@ -1,61 +0,0 @@ -model: - target: imagedream.ldm.interface.LatentDiffusionInterface - params: - linear_start: 0.00085 - linear_end: 0.0120 - timesteps: 1000 - scale_factor: 0.18215 - parameterization: "eps" - - unet_config: - target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 - params: - image_size: 32 # unused - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - use_checkpoint: False - legacy: False - camera_dim: 16 - with_ip: True - ip_dim: 16 # ip token length - ip_mode: "local_resample" - - vae_config: - target: imagedream.ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - clip_config: - target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml deleted file mode 100644 index bd9c835f06707aeffcd5cd8fd4cd7ec262646905..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml +++ /dev/null @@ -1,62 +0,0 @@ -model: - target: imagedream.ldm.interface.LatentDiffusionInterface - params: - linear_start: 0.00085 - linear_end: 0.0120 - timesteps: 1000 - scale_factor: 0.18215 - parameterization: "eps" - zero_snr: true - - unet_config: - target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 - params: - image_size: 32 # unused - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - use_checkpoint: False - legacy: False - camera_dim: 16 - with_ip: True - ip_dim: 16 # ip token length - ip_mode: "local_resample" - - vae_config: - target: imagedream.ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - clip_config: - target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml deleted file mode 100644 index b4fbe27c84d5c2b05737894dc8970f45f08ba606..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml +++ /dev/null @@ -1,62 +0,0 @@ -model: - target: imagedream.ldm.interface.LatentDiffusionInterface - params: - linear_start: 0.00085 - linear_end: 0.0120 - timesteps: 1000 - scale_factor: 0.18215 - parameterization: "eps" - - unet_config: - target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - use_checkpoint: False - legacy: False - camera_dim: 16 - with_ip: True - ip_dim: 16 # ip token length - ip_mode: "local_resample" - ip_weight: 1.0 # adjust for similarity to image - - vae_config: - target: imagedream.ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - clip_config: - target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml deleted file mode 100644 index 5824cd7f25a9d1ef2e397341a39f7c724eb1cf76..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml +++ /dev/null @@ -1,62 +0,0 @@ -model: - target: imagedream.ldm.interface.LatentDiffusionInterface - params: - linear_start: 0.00085 - linear_end: 0.0120 - timesteps: 1000 - scale_factor: 0.18215 - parameterization: "eps" - zero_snr: true - - unet_config: - target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - use_checkpoint: False - legacy: False - camera_dim: 16 - with_ip: True - ip_dim: 16 # ip token length - ip_mode: "local_resample" - - vae_config: - target: imagedream.ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - clip_config: - target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/ldm/__init__.py b/apps/third_party/CRM/imagedream/ldm/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index d029d6984122456fadb1fcb2eb074aa9471b82d4..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 02bf8990d7b4f5cdaf387895d86be5e17bf8ed39..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc deleted file mode 100644 index 01d44982e9c3c83b6c7479b28fe0604fd8f79fbd..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc deleted file mode 100644 index d9f7089053acea4170588b4fc2a48adf647beaf0..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc deleted file mode 100644 index 0890976472bd4db53fbbf6883313845de8002651..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc deleted file mode 100644 index e30a7fbad8ea87fce676138a6a8e481191d09dd7..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/interface.py b/apps/third_party/CRM/imagedream/ldm/interface.py deleted file mode 100644 index 3bbeed7921ea30efb81573c3efed42d8375cb21a..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/interface.py +++ /dev/null @@ -1,206 +0,0 @@ -from typing import List -from functools import partial - -import numpy as np -import torch -import torch.nn as nn - -from .modules.diffusionmodules.util import ( - make_beta_schedule, - extract_into_tensor, - enforce_zero_terminal_snr, - noise_like, -) -from .util import exists, default, instantiate_from_config -from .modules.distributions.distributions import DiagonalGaussianDistribution - - -class DiffusionWrapper(nn.Module): - def __init__(self, diffusion_model): - super().__init__() - self.diffusion_model = diffusion_model - - def forward(self, *args, **kwargs): - return self.diffusion_model(*args, **kwargs) - - -class LatentDiffusionInterface(nn.Module): - """a simple interface class for LDM inference""" - - def __init__( - self, - unet_config, - clip_config, - vae_config, - parameterization="eps", - scale_factor=0.18215, - beta_schedule="linear", - timesteps=1000, - linear_start=0.00085, - linear_end=0.0120, - cosine_s=8e-3, - given_betas=None, - zero_snr=False, - *args, - **kwargs, - ): - super().__init__() - - unet = instantiate_from_config(unet_config) - self.model = DiffusionWrapper(unet) - self.clip_model = instantiate_from_config(clip_config) - self.vae_model = instantiate_from_config(vae_config) - - self.parameterization = parameterization - self.scale_factor = scale_factor - self.register_schedule( - given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - zero_snr=zero_snr - ) - - def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - zero_snr=False - ): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - if zero_snr: - print("--- using zero snr---") - betas = enforce_zero_terminal_snr(betas).numpy() - alphas = 1.0 - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert ( - alphas_cumprod.shape[0] == self.num_timesteps - ), "alphas have to be defined for each timestep" - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer("betas", to_torch(betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) - ) - eps = 1e-8 # adding small epsilon value to avoid devide by zero error - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps))) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1)) - ) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - self.v_posterior = 0 - posterior_variance = (1 - self.v_posterior) * betas * ( - 1.0 - alphas_cumprod_prev - ) / (1.0 - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer("posterior_variance", to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer( - "posterior_log_variance_clipped", - to_torch(np.log(np.maximum(posterior_variance, 1e-20))), - ) - self.register_buffer( - "posterior_mean_coef1", - to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), - ) - self.register_buffer( - "posterior_mean_coef2", - to_torch( - (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) - ), - ) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise - ) - - def get_v(self, x, noise, t): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x - ) - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - * noise - ) - - def predict_start_from_z_and_v(self, x_t, t, v): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v - ) - - def predict_eps_from_z_and_v(self, x_t, t, v): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) - * x_t - ) - - def apply_model(self, x_noisy, t, cond, **kwargs): - assert isinstance(cond, dict), "cond has to be a dictionary" - return self.model(x_noisy, t, **cond, **kwargs) - - def get_learned_conditioning(self, prompts: List[str]): - return self.clip_model(prompts) - - def get_learned_image_conditioning(self, images): - return self.clip_model.forward_image(images) - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError( - f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" - ) - return self.scale_factor * z - - def encode_first_stage(self, x): - return self.vae_model.encode(x) - - def decode_first_stage(self, z): - z = 1.0 / self.scale_factor * z - return self.vae_model.decode(z) diff --git a/apps/third_party/CRM/imagedream/ldm/models/__init__.py b/apps/third_party/CRM/imagedream/ldm/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 672ea939d54368780d02dd86a921336018342e63..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index e64c9b27cd3f052820627d22906cadc4b3e9dd43..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc deleted file mode 100644 index 6ab9a034b2a331fc87a7e50e1a44af00ff39efac..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc deleted file mode 100644 index 5e3a48c1e459902bdb6f0621a7ede086b9c8cd08..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py b/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py deleted file mode 100644 index 92f83096ddaf2146772f6b49b23d8f99e787fbb4..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py +++ /dev/null @@ -1,270 +0,0 @@ -import torch -import torch.nn.functional as F -from contextlib import contextmanager - -from ..modules.diffusionmodules.model import Encoder, Decoder -from ..modules.distributions.distributions import DiagonalGaussianDistribution - -from ..util import instantiate_from_config -from ..modules.ema import LitEma - - -class AutoencoderKL(torch.nn.Module): - def __init__( - self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False, - ): - super().__init__() - self.learn_logvar = learn_logvar - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels) == int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - - self.use_ema = ema_decay is not None - if self.use_ema: - self.ema_decay = ema_decay - assert 0.0 < ema_decay < 1.0 - self.model_ema = LitEma(self, decay=ema_decay) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss( - inputs, - reconstructions, - posterior, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - ) - self.log( - "aeloss", - aeloss, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) - self.log_dict( - log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False - ) - return aeloss - - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss( - inputs, - reconstructions, - posterior, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - ) - - self.log( - "discloss", - discloss, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) - self.log_dict( - log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False - ) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, postfix=""): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss( - inputs, - reconstructions, - posterior, - 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val" + postfix, - ) - - discloss, log_dict_disc = self.loss( - inputs, - reconstructions, - posterior, - 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val" + postfix, - ) - - self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - ae_params_list = ( - list(self.encoder.parameters()) - + list(self.decoder.parameters()) - + list(self.quant_conv.parameters()) - + list(self.post_quant_conv.parameters()) - ) - if self.learn_logvar: - print(f"{self.__class__.__name__}: Learning logvar") - ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam( - self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) - ) - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - if log_ema or self.use_ema: - with self.ema_scope(): - xrec_ema, posterior_ema = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec_ema.shape[1] > 3 - xrec_ema = self.to_rgb(xrec_ema) - log["samples_ema"] = self.decode( - torch.randn_like(posterior_ema.sample()) - ) - log["reconstructions_ema"] = xrec_ema - log["inputs"] = x - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index a5c8f08217241b9ddc7d371105a50328a41ca897..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 9f083019b05ad6e286010c0cd0a21e6d9f9b2b9d..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc deleted file mode 100644 index 62e7bce7e47a73aa87c94c0cfdd772df724aff7a..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc deleted file mode 100644 index 0512edd43d2c8d4c576a7076c33332b30b0e0d6e..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py b/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py deleted file mode 100644 index 4c10321c39078e985f18a0ca7b388086a4aa4e2f..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py +++ /dev/null @@ -1,430 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm -from functools import partial - -from ...modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) - - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True - ): - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer("betas", to_torch(self.model.betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer( - "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer("ddim_sigmas", ddim_sigmas) - self.register_buffer("ddim_alphas", ddim_alphas) - self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) - self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps - ) - - @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print( - f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" - ) - else: - if conditioning.shape[0] != batch_size: - print( - f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" - ) - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - - samples, intermediates = self.ddim_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - **kwargs, - ) - return samples, intermediates - - @torch.no_grad() - def ddim_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - **kwargs, - ): - """ - when inference time: all values of parameter - cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img']) - shape: (5, 4, 32, 32) - x_T: None - ddim_use_original_steps: False - timesteps: None - callback: None - quantize_denoised: False - mask: None - image_callback: None - log_every_t: 100 - temperature: 1.0 - noise_dropout: 0.0 - score_corrector: None - corrector_kwargs: None - unconditional_guidance_scale: 5 - unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img']) - kwargs: {} - """ - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94 - else: - img = x_T - - if timesteps is None: # equal with set time step in hf - timesteps = ( - self.ddpm_num_timesteps - if ddim_use_original_steps - else self.ddim_timesteps - ) - elif timesteps is not None and not ddim_use_original_steps: - subset_end = ( - int( - min(timesteps / self.ddim_timesteps.shape[0], 1) - * self.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {"x_inter": [img], "pred_x0": [img]} - time_range = ( # reversed timesteps - reversed(range(0, timesteps)) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? - img = img_orig * mask + (1.0 - mask) * img - - outs = self.p_sample_ddim( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - **kwargs, - ) - img, pred_x0 = outs - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates["x_inter"].append(img) - intermediates["pred_x0"].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - dynamic_threshold=None, - **kwargs, - ): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: - model_output = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = dict() - for k in c: - if isinstance(c[k], list): - c_in[k] = [ - torch.cat([unconditional_conditioning[k][i], c[k][i]]) - for i in range(len(c[k])) - ] - elif isinstance(c[k], torch.Tensor): - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) - else: - assert c[k] == unconditional_conditioning[k] - c_in[k] = c[k] - elif isinstance(c, list): - c_in = list() - assert isinstance(unconditional_conditioning, list) - for i in range(len(c)): - c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) - else: - c_in = torch.cat([unconditional_conditioning, c]) - model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - model_output = model_uncond + unconditional_guidance_scale * ( - model_t - model_uncond - ) - - - if self.model.parameterization == "v": - print("using v!") - e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) - else: - e_t = model_output - - if score_corrector is not None: - assert self.model.parameterization == "eps", "not implemented" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) - - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return ( - extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 - + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise - ) - - @torch.no_grad() - def decode( - self, - x_latent, - cond, - t_start, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - **kwargs, - ): - timesteps = ( - np.arange(self.ddpm_num_timesteps) - if use_original_steps - else self.ddim_timesteps - ) - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - - iterator = tqdm(time_range, desc="Decoding image", total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full( - (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long - ) - x_dec, _ = self.p_sample_ddim( - x_dec, - cond, - ts, - index=index, - use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - **kwargs, - ) - return x_dec diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 188b2eab520e6aaf281181afb6d84e1a3098a914..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 8b89c02f3c6827cdac3513cca77f2efd913115b0..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc deleted file mode 100644 index e77dd1a9bb9eae984eb7a6f2f33a98a54662d64a..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc deleted file mode 100644 index 378dad50ffdfa150bf000062d153d2dd6fb715b0..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc deleted file mode 100644 index a6c836e51e276424666c8d22441a804bf8ae4722..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc deleted file mode 100644 index ed990a3f549b22da782b4f7e7cfe07bf64d4969c..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/attention.py b/apps/third_party/CRM/imagedream/ldm/modules/attention.py deleted file mode 100644 index 9578027d3d9e9b8766941dc0c986d42cd93b04bf..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/attention.py +++ /dev/null @@ -1,456 +0,0 @@ -from inspect import isfunction -import math -import torch -import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat -from typing import Optional, Any - -from .diffusionmodules.util import checkpoint - - -try: - import xformers - import xformers.ops - - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False - -# CrossAttn precision handling -import os - -_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = rearrange(q, "b c h w -> b (h w) c") - k = rearrange(k, "b c h w -> b c (h w)") - w_ = torch.einsum("bij,bjk->bik", q, k) - - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, "b c h w -> b c (h w)") - w_ = rearrange(w_, "b i j -> b j i") - h_ = torch.einsum("bij,bjk->bik", v, w_) - h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) - h_ = self.proj_out(h_) - - return x + h_ - - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs): - super().__init__() - print( - f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads." - ) - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - - self.with_ip = kwargs.get("with_ip", False) - if self.with_ip and (context_dim is not None): - self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) - self.ip_dim= kwargs.get("ip_dim", 16) - self.ip_weight = kwargs.get("ip_weight", 1.0) - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - self.attention_op = None - - def forward(self, x, context=None, mask=None): - q = self.to_q(x) - - has_ip = self.with_ip and (context is not None) - if has_ip: - # context dim [(b frame_num), (77 + img_token), 1024] - token_len = context.shape[1] - context_ip = context[:, -self.ip_dim:, :] - k_ip = self.to_k_ip(context_ip) - v_ip = self.to_v_ip(context_ip) - context = context[:, :(token_len - self.ip_dim), :] - - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op - ) - - if has_ip: - k_ip, v_ip = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (k_ip, v_ip), - ) - # actually compute the attention, what we cannot get enough of - out_ip = xformers.ops.memory_efficient_attention( - q, k_ip, v_ip, attn_bias=None, op=self.attention_op - ) - out = out + self.ip_weight * out_ip - - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - return self.to_out(out) - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - disable_self_attn=False, - **kwargs - ): - super().__init__() - assert XFORMERS_IS_AVAILBLE, "xformers is not available" - attn_cls = MemoryEfficientCrossAttention - self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls( - query_dim=dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, - ) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - **kwargs - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - return checkpoint( - self._forward, (x, context), self.parameters(), self.checkpoint - ) - - def _forward(self, x, context=None): - x = ( - self.attn1( - self.norm1(x), context=context if self.disable_self_attn else None - ) - + x - ) - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - NEW: use_linear for more efficiency instead of the 1x1 convs - """ - - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - disable_self_attn=False, - use_linear=False, - use_checkpoint=True, - **kwargs - ): - super().__init__() - if exists(context_dim) and not isinstance(context_dim, list): - context_dim = [context_dim] - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - if not use_linear: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - else: - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - n_heads, - d_head, - dropout=dropout, - context_dim=context_dim[d], - disable_self_attn=disable_self_attn, - checkpoint=use_checkpoint, - **kwargs - ) - for d in range(depth) - ] - ) - if not use_linear: - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) - else: - self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - self.use_linear = use_linear - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c").contiguous() - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) - if self.use_linear: - x = self.proj_out(x) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() - if not self.use_linear: - x = self.proj_out(x) - return x + x_in - - -class BasicTransformerBlock3D(BasicTransformerBlock): - def forward(self, x, context=None, num_frames=1): - return checkpoint( - self._forward, (x, context, num_frames), self.parameters(), self.checkpoint - ) - - def _forward(self, x, context=None, num_frames=1): - x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() - x = ( - self.attn1( - self.norm1(x), - context=context if self.disable_self_attn else None - ) - + x - ) - x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer3D(nn.Module): - """3D self-attention""" - - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - disable_self_attn=False, - use_linear=False, - use_checkpoint=True, - **kwargs - ): - super().__init__() - if exists(context_dim) and not isinstance(context_dim, list): - context_dim = [context_dim] - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - if not use_linear: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - else: - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock3D( - inner_dim, - n_heads, - d_head, - dropout=dropout, - context_dim=context_dim[d], - disable_self_attn=disable_self_attn, - checkpoint=use_checkpoint, - **kwargs - ) - for d in range(depth) - ] - ) - if not use_linear: - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) - else: - self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - self.use_linear = use_linear - - def forward(self, x, context=None, num_frames=1): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c").contiguous() - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i], num_frames=num_frames) - if self.use_linear: - x = self.proj_out(x) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() - if not self.use_linear: - x = self.proj_out(x) - return x + x_in diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 87936969c3da946fe6a3d832741e28e3f8c5a465..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 709d11f26d4279dc79749621e1fadb3d3214664c..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc deleted file mode 100644 index a776d040c73ba392b8c4498ceae40daa0ac5c375..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc deleted file mode 100644 index 009edfe033ed8828bdc0ae3ae92f56076fe6752a..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc deleted file mode 100644 index e42f7d66fb4493c5d29a2e00d78db7ee9459f090..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc deleted file mode 100644 index ef2270fc4896075b92413258975349f6c68c9bec..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc deleted file mode 100644 index a1a0a6de705d51589cf78005348fa64f9e7eb29e..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc deleted file mode 100644 index f5e8589aeb17fb3c779c10c01f058ddaf3165198..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc deleted file mode 100644 index 8aa9c29edf8fd31ddfd84c3b6784e57a6037b2fb..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc deleted file mode 100644 index 0a455c6a98a6d27b74477615cf6bf53f0899da51..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py deleted file mode 100644 index 8d66e480728073294015cf0eb906dba471d602ca..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py +++ /dev/null @@ -1,163 +0,0 @@ -# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py -import math - -import torch -import torch.nn as nn - - -# FFN -def FeedForward(dim, mult=4): - inner_dim = int(dim * mult) - return nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, inner_dim, bias=False), - nn.GELU(), - nn.Linear(inner_dim, dim, bias=False), - ) - - -def reshape_tensor(x, heads): - bs, length, width = x.shape - #(bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, length, heads, -1) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs, heads, length, -1) - return x - - -class PerceiverAttention(nn.Module): - def __init__(self, *, dim, dim_head=64, heads=8): - super().__init__() - self.scale = dim_head**-0.5 - self.dim_head = dim_head - self.heads = heads - inner_dim = dim_head * heads - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - - def forward(self, x, latents): - """ - Args: - x (torch.Tensor): image features - shape (b, n1, D) - latent (torch.Tensor): latent features - shape (b, n2, D) - """ - x = self.norm1(x) - latents = self.norm2(latents) - - b, l, _ = latents.shape - - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) - - # attention - scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v - - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) - - return self.to_out(out) - - -class ImageProjModel(torch.nn.Module): - """Projection Model""" - def __init__(self, - cross_attention_dim=1024, - clip_embeddings_dim=1024, - clip_extra_context_tokens=4): - super().__init__() - self.cross_attention_dim = cross_attention_dim - self.clip_extra_context_tokens = clip_extra_context_tokens - - # from 1024 -> 4 * 1024 - self.proj = torch.nn.Linear( - clip_embeddings_dim, - self.clip_extra_context_tokens * cross_attention_dim) - self.norm = torch.nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds): - embeds = image_embeds - clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) - clip_extra_context_tokens = self.norm(clip_extra_context_tokens) - return clip_extra_context_tokens - - -class SimpleReSampler(nn.Module): - def __init__(self, embedding_dim=1280, output_dim=1024): - super().__init__() - self.proj_out = nn.Linear(embedding_dim, output_dim) - self.norm_out = nn.LayerNorm(output_dim) - - def forward(self, latents): - """ - latents: B 256 N - """ - latents = self.proj_out(latents) - return self.norm_out(latents) - - -class Resampler(nn.Module): - def __init__( - self, - dim=1024, - depth=8, - dim_head=64, - heads=16, - num_queries=8, - embedding_dim=768, - output_dim=1024, - ff_mult=4, - ): - super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) - self.proj_in = nn.Linear(embedding_dim, dim) - self.proj_out = nn.Linear(dim, output_dim) - self.norm_out = nn.LayerNorm(output_dim) - - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - PerceiverAttention(dim=dim, - dim_head=dim_head, - heads=heads), - FeedForward(dim=dim, mult=ff_mult), - ] - ) - ) - - def forward(self, x): - latents = self.latents.repeat(x.size(0), 1, 1) - x = self.proj_in(x) - for attn, ff in self.layers: - latents = attn(x, latents) + latents - latents = ff(latents) + latents - - latents = self.proj_out(latents) - return self.norm_out(latents) - - -if __name__ == '__main__': - resampler = Resampler(embedding_dim=1280) - resampler = SimpleReSampler(embedding_dim=1280) - tensor = torch.rand(4, 257, 1280) - embed = resampler(tensor) - # embed = (tensor) - print(embed.shape) diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py deleted file mode 100644 index 52a1d5c8e8ba62dd25133ffe76d370c637c5d25e..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py +++ /dev/null @@ -1,1018 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange -from typing import Optional, Any - -from ..attention import MemoryEfficientCrossAttention - -try: - import xformers - import xformers.ops - - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False - print("No module 'xformers'. Proceeding without it.") - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.attention_op = None - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) - - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op - ) - - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) - out = self.proj_out(out) - return x + out - - -class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): - b, c, h, w = x.shape - x = rearrange(x, "b c h w -> b (h w) c") - out = super().forward(x, context=context, mask=mask) - out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) - return x + out - - -def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in [ - "vanilla", - "vanilla-xformers", - "memory-efficient-cross-attn", - "linear", - "none", - ], f"attn_type {attn_type} unknown" - if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": - attn_type = "vanilla-xformers" - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - assert attn_kwargs is None - return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": - print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") - return MemoryEfficientAttnBlock(in_channels) - elif type == "memory-efficient-cross-attn": - attn_kwargs["query_dim"] = in_channels - return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - raise NotImplementedError() - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla", - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print( - "Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape) - ) - ) - - # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList( - [ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock( - in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True), - ] - ) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1, 2, 3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - ch, - num_res_blocks, - resolution, - ch_mult=(2, 2), - dropout=0.0, - ): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d( - in_channels, mid_channels, kernel_size=3, stride=1, padding=1 - ) - self.res_block1 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - - self.conv_out = nn.Conv2d( - mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate( - x, - size=( - int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)), - ), - ) - x = self.attn(x) - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__( - self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder( - in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__( - self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder( - out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1.0 + (out_size % in_size) - print( - f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" - ) - self.rescaler = LatentRescaler( - factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels, - ) - self.decoder = Decoder( - out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)], - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" - ) - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=4, stride=2, padding=1 - ) - - def forward(self, x, scale_factor=1.0): - if scale_factor == 1.0: - return x - else: - x = torch.nn.functional.interpolate( - x, mode=self.mode, align_corners=False, scale_factor=scale_factor - ) - return x diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py deleted file mode 100644 index 2f12a389584a729e1177af8683d631e5c5d77fd5..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py +++ /dev/null @@ -1,1135 +0,0 @@ -from abc import abstractmethod -import math - -import numpy as np -import torch -import torch as th -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat - -from imagedream.ldm.modules.diffusionmodules.util import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, - convert_module_to_f16, - convert_module_to_f32 -) -from imagedream.ldm.modules.attention import ( - SpatialTransformer, - SpatialTransformer3D, - exists -) -from imagedream.ldm.modules.diffusionmodules.adaptors import ( - Resampler, - ImageProjModel -) - -## go -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 - ) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None, num_frames=1): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer3D): - x = layer(x, context, num_frames=num_frames) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd( - dims, self.channels, self.out_channels, 3, padding=padding - ) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class TransposedUpsample(nn.Module): - "Learned 2x upsampling without padding" - - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d( - self.channels, self.out_channels, kernel_size=ks, stride=2 - ) - - def forward(self, x): - return self.up(x) - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, - self.channels, - self.out_channels, - 3, - stride=stride, - padding=padding, - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - return checkpoint( - self._forward, (x,), self.parameters(), True - ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - # return pt_checkpoint(self._forward, x) # pytorch - - def _forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial**2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class Timestep(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, t): - return timestep_embedding(t, self.dim) - - -class MultiViewUNetModel(nn.Module): - """ - The full multi-view UNet model with attention, timestep embedding and camera embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - :param camera_dim: dimensionality of camera input. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - use_bf16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - disable_self_attentions=None, - num_attention_blocks=None, - disable_middle_self_attn=False, - use_linear_in_transformer=False, - adm_in_channels=None, - camera_dim=None, - with_ip=False, # wether add image prompt images - ip_dim=0, # number of extra token, 4 for global 16 for local - ip_weight=1.0, # weight for image prompt context - ip_mode="local_resample", # which mode of adaptor, global or local - ): - super().__init__() - if use_spatial_transformer: - assert ( - context_dim is not None - ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." - - if context_dim is not None: - assert ( - use_spatial_transformer - ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." - from omegaconf.listconfig import ListConfig - - if type(context_dim) == ListConfig: - context_dim = list(context_dim) - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert ( - num_head_channels != -1 - ), "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert ( - num_heads != -1 - ), "Either num_heads or num_head_channels has to be set" - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - if isinstance(num_res_blocks, int): - self.num_res_blocks = len(channel_mult) * [num_res_blocks] - else: - if len(num_res_blocks) != len(channel_mult): - raise ValueError( - "provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult" - ) - self.num_res_blocks = num_res_blocks - if disable_self_attentions is not None: - # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not - assert len(disable_self_attentions) == len(channel_mult) - if num_attention_blocks is not None: - assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all( - map( - lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], - range(len(num_attention_blocks)), - ) - ) - print( - f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set." - ) - - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.dtype = th.bfloat16 if use_bf16 else self.dtype - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - - self.with_ip = with_ip # wether there is image prompt - self.ip_dim = ip_dim # num of extra token, 4 for global 16 for local - self.ip_weight = ip_weight - assert ip_mode in ["global", "local_resample"] - self.ip_mode = ip_mode # which mode of adaptor - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if camera_dim is not None: - time_embed_dim = model_channels * 4 - self.camera_embed = nn.Sequential( - linear(camera_dim, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - elif self.num_classes == "continuous": - print("setting up linear c_adm embedding layer") - self.label_emb = nn.Linear(1, time_embed_dim) - elif self.num_classes == "sequential": - assert adm_in_channels is not None - self.label_emb = nn.Sequential( - nn.Sequential( - linear(adm_in_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - ) - else: - raise ValueError() - - if self.with_ip and (context_dim is not None) and ip_dim > 0: - if self.ip_mode == "local_resample": - # ip-adapter-plus - hidden_dim = 1280 - self.image_embed = Resampler( - dim=context_dim, - depth=4, - dim_head=64, - heads=12, - num_queries=ip_dim, # num token - embedding_dim=hidden_dim, - output_dim=context_dim, - ff_mult=4, - ) - elif self.ip_mode == "global": - self.image_embed = ImageProjModel( - cross_attention_dim=context_dim, - clip_extra_context_tokens=ip_dim) - else: - raise ValueError(f"{self.ip_mode} is not supported") - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for nr in range(self.num_res_blocks[level]): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - if exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if ( - not exists(num_attention_blocks) - or nr < num_attention_blocks[level] - ): - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer3D( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - disable_self_attn=disabled_sa, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, - with_ip=self.with_ip, - ip_dim=self.ip_dim, - ip_weight=self.ip_weight - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer3D( # always uses a self-attn - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, - with_ip=self.with_ip, - ip_dim=self.ip_dim, - ip_weight=self.ip_weight - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(self.num_res_blocks[level] + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - if exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if ( - not exists(num_attention_blocks) - or i < num_attention_blocks[level] - ): - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer3D( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - disable_self_attn=disabled_sa, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, - with_ip=self.with_ip, - ip_dim=self.ip_dim, - ip_weight=self.ip_weight - ) - ) - if level and i == self.num_res_blocks[level]: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - if self.predict_codebook_ids: - self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward( - self, - x, - timesteps=None, - context=None, - y=None, - camera=None, - num_frames=1, - **kwargs, - ): - """ - Apply the model to an input batch. - :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). - :param timesteps: a 1-D batch of timesteps. - :param context: a dict conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional, default None. - :param num_frames: a integer indicating number of frames for tensor reshaping. - :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). - """ - assert ( - x.shape[0] % num_frames == 0 - ), "[UNet] input batch size must be dividable by num_frames!" - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00 - emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51 - - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - # Add camera embeddings - if camera is not None: - assert camera.shape[0] == emb.shape[0] - # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04 - emb = emb + self.camera_embed(camera) - ip = kwargs.get("ip", None) - ip_img = kwargs.get("ip_img", None) - - if ip_img is not None: - x[(num_frames-1)::num_frames, :, :, :] = ip_img - - if ip is not None: - ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 - context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context, num_frames=num_frames) - hs.append(h) - h = self.middle_block(h, emb, context, num_frames=num_frames) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context, num_frames=num_frames) - h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58 - if self.predict_codebook_ids: # False - return self.id_predictor(h) - else: - return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93 - - - - -class MultiViewUNetModelStage2(MultiViewUNetModel): - """ - The full multi-view UNet model with attention, timestep embedding and camera embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - :param camera_dim: dimensionality of camera input. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - use_bf16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - disable_self_attentions=None, - num_attention_blocks=None, - disable_middle_self_attn=False, - use_linear_in_transformer=False, - adm_in_channels=None, - camera_dim=None, - with_ip=False, # wether add image prompt images - ip_dim=0, # number of extra token, 4 for global 16 for local - ip_weight=1.0, # weight for image prompt context - ip_mode="local_resample", # which mode of adaptor, global or local - ): - super().__init__( - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout, - channel_mult, - conv_resample, - dims, - num_classes, - use_checkpoint, - use_fp16, - use_bf16, - num_heads, - num_head_channels, - num_heads_upsample, - use_scale_shift_norm, - resblock_updown, - use_new_attention_order, - use_spatial_transformer, - transformer_depth, - context_dim, - n_embed, - legacy, - disable_self_attentions, - num_attention_blocks, - disable_middle_self_attn, - use_linear_in_transformer, - adm_in_channels, - camera_dim, - with_ip, - ip_dim, - ip_weight, - ip_mode, - ) - - def forward( - self, - x, - timesteps=None, - context=None, - y=None, - camera=None, - num_frames=1, - **kwargs, - ): - """ - Apply the model to an input batch. - :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). - :param timesteps: a 1-D batch of timesteps. - :param context: a dict conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional, default None. - :param num_frames: a integer indicating number of frames for tensor reshaping. - :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). - """ - assert ( - x.shape[0] % num_frames == 0 - ), "[UNet] input batch size must be dividable by num_frames!" - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00 - emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51 - - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - # Add camera embeddings - if camera is not None: - assert camera.shape[0] == emb.shape[0] - # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04 - emb = emb + self.camera_embed(camera) - ip = kwargs.get("ip", None) - ip_img = kwargs.get("ip_img", None) - pixel_images = kwargs.get("pixel_images", None) - - if ip_img is not None: - x[(num_frames-1)::num_frames, :, :, :] = ip_img - - x = torch.cat((x, pixel_images), dim=1) - - if ip is not None: - ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 - context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context, num_frames=num_frames) - hs.append(h) - h = self.middle_block(h, emb, context, num_frames=num_frames) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context, num_frames=num_frames) - h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58 - if self.predict_codebook_ids: # False - return self.id_predictor(h) - else: - return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93 - \ No newline at end of file diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py deleted file mode 100644 index af744261d41deab5d686aead790726efcdfaf961..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py +++ /dev/null @@ -1,353 +0,0 @@ -# adopted from -# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -# and -# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -# and -# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py -# -# thanks! - - -import os -import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat -import importlib - - -def instantiate_from_config(config): - if not "target" in config: - if config == "__is_first_stage__": - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def make_beta_schedule( - schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 -): - if schedule == "linear": - betas = ( - torch.linspace( - linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 - ) - ** 2 - ) - - elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) - alphas = timesteps / (1 + cosine_s) * np.pi / 2 - alphas = torch.cos(alphas).pow(2) - alphas = alphas / alphas[0] - betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) - - elif schedule == "sqrt_linear": - betas = torch.linspace( - linear_start, linear_end, n_timestep, dtype=torch.float64 - ) - elif schedule == "sqrt": - betas = ( - torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - ** 0.5 - ) - else: - raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() - -def enforce_zero_terminal_snr(betas): - betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas - # Convert betas to alphas_bar_sqrt - alphas =1 - betas - alphas_bar = alphas.cumprod(0) - alphas_bar_sqrt = alphas_bar.sqrt() - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - # Shift so last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - # Scale so first timestep is back to old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt ** 2 - alphas = alphas_bar[1:] / alphas_bar[:-1] - alphas = torch.cat ([alphas_bar[0:1], alphas]) - betas = 1 - alphas - return betas - - -def make_ddim_timesteps( - ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True -): - if ddim_discr_method == "uniform": - c = num_ddpm_timesteps // num_ddim_timesteps - ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == "quad": - ddim_timesteps = ( - (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 - ).astype(int) - else: - raise NotImplementedError( - f'There is no ddim discretization method called "{ddim_discr_method}"' - ) - - # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 - if verbose: - print(f"Selected timesteps for ddim sampler: {steps_out}") - return steps_out - - -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): - # select alphas for computing the variance schedule - alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) - - # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt( - (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) - ) - if verbose: - print( - f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" - ) - print( - f"For the chosen value of eta, which is {eta}, " - f"this results in the following sigma_t schedule for ddim sampler {sigmas}" - ) - return sigmas, alphas, alphas_prev - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - # import pdb; pdb.set_trace() - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): - super().__init__() - self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) - - def forward(self, c_concat, c_crossattn): - c_concat = self.concat_conditioner(c_concat) - c_crossattn = self.crossattn_conditioner(c_crossattn) - return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( - shape[0], *((1,) * (len(shape) - 1)) - ) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() - - -# dummy replace -def convert_module_to_f16(l): - """ - Convert primitive modules to float16. - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - -def convert_module_to_f32(l): - """ - Convert primitive modules to float32, undoing convert_module_to_f16(). - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.float() - if l.bias is not None: - l.bias.data = l.bias.data.float() diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 312f98ad511d7f30d3483685fa70528b1910b20a..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 30cc9dbbbadd654864db556064cf17f299a66b28..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc deleted file mode 100644 index c80371d54b3d6110586bfe2f53628dbae828ecde..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc deleted file mode 100644 index ee436d35f6a01e2ff87c0588b6ec2bdf7211f43b..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py b/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py deleted file mode 100644 index 92f4428a3defd8fbae18fcd323c9d404036c652e..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import numpy as np - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( - device=self.parameters.device - ) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/apps/third_party/CRM/imagedream/ldm/modules/ema.py b/apps/third_party/CRM/imagedream/ldm/modules/ema.py deleted file mode 100644 index a073e116975f3313fb630b7d4ac115171c1fe31d..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/ema.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.m_name2s_name = {} - self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) - self.register_buffer( - "num_updates", - torch.tensor(0, dtype=torch.int) - if use_num_upates - else torch.tensor(-1, dtype=torch.int), - ) - - for name, p in model.named_parameters(): - if p.requires_grad: - # remove as '.'-character is not allowed in buffers - s_name = name.replace(".", "") - self.m_name2s_name.update({name: s_name}) - self.register_buffer(s_name, p.clone().detach().data) - - self.collected_params = [] - - def reset_num_updates(self): - del self.num_updates - self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) - - def forward(self, model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_( - one_minus_decay * (shadow_params[sname] - m_param[key]) - ) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 6ad2b831f070c88d8b1b6d35c697cbb2b8466d66..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 98745a81153dfd563cb689b76944075d93cd4feb..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc deleted file mode 100644 index 8c722f5bdac9cc12253ad9934ee36901b8468d57..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc deleted file mode 100644 index 70899c6fc30bc670956be3ac51675b27beae2d02..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py b/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py deleted file mode 100644 index a19d2c193d660535f101fa1cc3f1b857ce1197fc..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py +++ /dev/null @@ -1,329 +0,0 @@ -import torch -import torch.nn as nn -from torch.utils.checkpoint import checkpoint - -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel - -import numpy as np -import open_clip -from PIL import Image -from ...util import default, count_params - - -class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - - def encode(self, *args, **kwargs): - raise NotImplementedError - - -class IdentityEncoder(AbstractEncoder): - def encode(self, x): - return x - - -class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): - super().__init__() - self.key = key - self.embedding = nn.Embedding(n_classes, embed_dim) - self.n_classes = n_classes - self.ucg_rate = ucg_rate - - def forward(self, batch, key=None, disable_dropout=False): - if key is None: - key = self.key - # this is for use in crossattn - c = batch[key][:, None] - if self.ucg_rate > 0.0 and not disable_dropout: - mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) - c = c.long() - c = self.embedding(c) - return c - - def get_unconditional_conditioning(self, bs, device="cuda"): - uc_class = ( - self.n_classes - 1 - ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) - uc = torch.ones((bs,), device=device) * uc_class - uc = {self.key: uc} - return uc - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -class FrozenT5Embedder(AbstractEncoder): - """Uses the T5 transformer encoder for text""" - - def __init__( - self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True - ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl - super().__init__() - self.tokenizer = T5Tokenizer.from_pretrained(version) - self.transformer = T5EncoderModel.from_pretrained(version) - self.device = device - self.max_length = max_length # TODO: typical value? - if freeze: - self.freeze() - - def freeze(self): - self.transformer = self.transformer.eval() - # self.train = disabled_train - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens) - - z = outputs.last_hidden_state - return z - - def encode(self, text): - return self(text) - - -class FrozenCLIPEmbedder(AbstractEncoder): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - version="openai/clip-vit-large-patch14", - device="cuda", - max_length=77, - freeze=True, - layer="last", - layer_idx=None, - ): # clip-vit-base-patch32 - super().__init__() - assert layer in self.LAYERS - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - self.layer = layer - self.layer_idx = layer_idx - if layer == "hidden": - assert layer_idx is not None - assert 0 <= abs(layer_idx) <= 12 - - def freeze(self): - self.transformer = self.transformer.eval() - # self.train = disabled_train - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer( - input_ids=tokens, output_hidden_states=self.layer == "hidden" - ) - if self.layer == "last": - z = outputs.last_hidden_state - elif self.layer == "pooled": - z = outputs.pooler_output[:, None, :] - else: - z = outputs.hidden_states[self.layer_idx] - return z - - def encode(self, text): - return self(text) - - -class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module): - """ - Uses the OpenCLIP transformer encoder for text - """ - - LAYERS = [ - # "pooled", - "last", - "penultimate", - ] - - def __init__( - self, - arch="ViT-H-14", - version="laion2b_s32b_b79k", - device="cuda", - max_length=77, - freeze=True, - layer="last", - ip_mode=None - ): - """_summary_ - - Args: - ip_mode (str, optional): what is the image promcessing mode. Defaults to None. - - """ - super().__init__() - assert layer in self.LAYERS - model, _, preprocess = open_clip.create_model_and_transforms( - arch, device=torch.device("cpu"), pretrained=version - ) - if ip_mode is None: - del model.visual - - self.model = model - self.preprocess = preprocess - self.device = device - self.max_length = max_length - self.ip_mode = ip_mode - if freeze: - self.freeze() - self.layer = layer - if self.layer == "last": - self.layer_idx = 0 - elif self.layer == "penultimate": - self.layer_idx = 1 - else: - raise NotImplementedError() - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - tokens = open_clip.tokenize(text) - z = self.encode_with_transformer(tokens.to(self.device)) - return z - - def forward_image(self, pil_image): - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - if isinstance(pil_image, torch.Tensor): - pil_image = pil_image.cpu().numpy() - if isinstance(pil_image, np.ndarray): - if pil_image.ndim == 3: - pil_image = pil_image[None, :, :, :] - pil_image = [Image.fromarray(x) for x in pil_image] - - images = [] - for image in pil_image: - images.append(self.preprocess(image).to(self.device)) - - image = torch.stack(images, 0) # to [b, 3, h, w] - if self.ip_mode == "global": - image_features = self.model.encode_image(image) - image_features /= image_features.norm(dim=-1, keepdim=True) - elif "local" in self.ip_mode: - image_features = self.encode_image_with_transformer(image) - - return image_features # b, l - - def encode_image_with_transformer(self, x): - visual = self.model.visual - x = visual.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - - # class embeddings and positional embeddings - x = torch.cat( - [visual.class_embedding.to(x.dtype) + \ - torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), - x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + visual.positional_embedding.to(x.dtype) - - # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in - # x = visual.patch_dropout(x) - x = visual.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - hidden = self.image_transformer_forward(x) - x = hidden[-2].permute(1, 0, 2) # LND -> NLD - return x - - def image_transformer_forward(self, x): - encoder_states = () - trans = self.model.visual.transformer - for r in trans.resblocks: - if trans.grad_checkpointing and not torch.jit.is_scripting(): - # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 - x = checkpoint(r, x, None, None, None) - else: - x = r(x, attn_mask=None) - encoder_states = encoder_states + (x, ) - return encoder_states - - def encode_with_transformer(self, text): - x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] - x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.model.ln_final(x) - return x - - def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): - for i, r in enumerate(self.model.transformer.resblocks): - if i == len(self.model.transformer.resblocks) - self.layer_idx: - break - if ( - self.model.transformer.grad_checkpointing - and not torch.jit.is_scripting() - ): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - - def encode(self, text): - return self(text) - - -class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__( - self, - clip_version="openai/clip-vit-large-patch14", - t5_version="google/t5-v1_1-xl", - device="cuda", - clip_max_length=77, - t5_max_length=77, - ): - super().__init__() - self.clip_encoder = FrozenCLIPEmbedder( - clip_version, device, max_length=clip_max_length - ) - self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print( - f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." - ) - - def encode(self, text): - return self(text) - - def forward(self, text): - clip_z = self.clip_encoder.encode(text) - t5_z = self.t5_encoder.encode(text) - return [clip_z, t5_z] diff --git a/apps/third_party/CRM/imagedream/ldm/util.py b/apps/third_party/CRM/imagedream/ldm/util.py deleted file mode 100644 index 8ed44393d153aad633175d3afda2a1d923c7d815..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/ldm/util.py +++ /dev/null @@ -1,231 +0,0 @@ -import importlib - -import random -import torch -import numpy as np -from collections import abc - -import multiprocessing as mp -from threading import Thread -from queue import Queue - -from inspect import isfunction -from PIL import Image, ImageDraw, ImageFont - - -def log_txt_as_img(wh, xc, size=10): - # wh a tuple of (width, height) - # xc a list of captions to plot - b = len(xc) - txts = list() - for bi in range(b): - txt = Image.new("RGB", wh, color="white") - draw = ImageDraw.Draw(txt) - font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) - nc = int(40 * (wh[0] / 256)) - lines = "\n".join( - xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) - ) - - try: - draw.text((0, 0), lines, fill="black", font=font) - except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") - - txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 - txts.append(txt) - txts = np.stack(txts) - txts = torch.tensor(txts) - return txts - - -def ismap(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] > 3) - - -def isimage(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def mean_flat(tensor): - """ - https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if config == "__is_first_stage__": - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - # import pdb; pdb.set_trace() - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - # import pdb; pdb.set_trace() - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - - if 'imagedream' in module: - module = 'apps.third_party.CRM.'+module - if 'lib' in module: - module = 'apps.third_party.CRM.'+module - return getattr(importlib.import_module(module, package=None), cls) - - -def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): - # create dummy dataset instance - - # run prefetching - if idx_to_fn: - res = func(data, worker_id=idx) - else: - res = func(data) - Q.put([idx, res]) - Q.put("Done") - - -def parallel_data_prefetch( - func: callable, - data, - n_proc, - target_data_type="ndarray", - cpu_intensive=True, - use_worker_id=False, -): - # if target_data_type not in ["ndarray", "list"]: - # raise ValueError( - # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." - # ) - if isinstance(data, np.ndarray) and target_data_type == "list": - raise ValueError("list expected but function got ndarray.") - elif isinstance(data, abc.Iterable): - if isinstance(data, dict): - print( - f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' - ) - data = list(data.values()) - if target_data_type == "ndarray": - data = np.asarray(data) - else: - data = list(data) - else: - raise TypeError( - f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." - ) - - if cpu_intensive: - Q = mp.Queue(1000) - proc = mp.Process - else: - Q = Queue(1000) - proc = Thread - # spawn processes - if target_data_type == "ndarray": - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate(np.array_split(data, n_proc)) - ] - else: - step = ( - int(len(data) / n_proc + 1) - if len(data) % n_proc != 0 - else int(len(data) / n_proc) - ) - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate( - [data[i : i + step] for i in range(0, len(data), step)] - ) - ] - processes = [] - for i in range(n_proc): - p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) - processes += [p] - - # start processes - print(f"Start prefetching...") - import time - - start = time.time() - gather_res = [[] for _ in range(n_proc)] - try: - for p in processes: - p.start() - - k = 0 - while k < n_proc: - # get result - res = Q.get() - if res == "Done": - k += 1 - else: - gather_res[res[0]] = res[1] - - except Exception as e: - print("Exception: ", e) - for p in processes: - p.terminate() - - raise e - finally: - for p in processes: - p.join() - print(f"Prefetching complete. [{time.time() - start} sec.]") - - if target_data_type == "ndarray": - if not isinstance(gather_res[0], np.ndarray): - return np.concatenate([np.asarray(r) for r in gather_res], axis=0) - - # order outputs - return np.concatenate(gather_res, axis=0) - elif target_data_type == "list": - out = [] - for r in gather_res: - out.extend(r) - return out - else: - return gather_res - -def set_seed(seed=None): - random.seed(seed) - np.random.seed(seed) - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - -def add_random_background(image, bg_color=None): - bg_color = np.random.rand() * 255 if bg_color is None else bg_color - image = np.array(image) - rgb, alpha = image[..., :3], image[..., 3:] - alpha = alpha.astype(np.float32) / 255.0 - image_new = rgb * alpha + bg_color * (1 - alpha) - return Image.fromarray(image_new.astype(np.uint8)) \ No newline at end of file diff --git a/apps/third_party/CRM/imagedream/model_zoo.py b/apps/third_party/CRM/imagedream/model_zoo.py deleted file mode 100644 index 45d6b678bdc554f5b2ad19903f8ed9976ece024e..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/imagedream/model_zoo.py +++ /dev/null @@ -1,64 +0,0 @@ -""" Utiliy functions to load pre-trained models more easily """ -import os -import pkg_resources -from omegaconf import OmegaConf - -import torch -from huggingface_hub import hf_hub_download - -from imagedream.ldm.util import instantiate_from_config - - -PRETRAINED_MODELS = { - "sd-v2.1-base-4view-ipmv": { - "config": "sd_v2_base_ipmv.yaml", - "repo_id": "Peng-Wang/ImageDream", - "filename": "sd-v2.1-base-4view-ipmv.pt", - }, - "sd-v2.1-base-4view-ipmv-local": { - "config": "sd_v2_base_ipmv_local.yaml", - "repo_id": "Peng-Wang/ImageDream", - "filename": "sd-v2.1-base-4view-ipmv-local.pt", - }, -} - - -def get_config_file(config_path): - cfg_file = pkg_resources.resource_filename( - "imagedream", os.path.join("configs", config_path) - ) - if not os.path.exists(cfg_file): - raise RuntimeError(f"Config {config_path} not available!") - return cfg_file - - -def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None): - if (config_path is not None) and (ckpt_path is not None): - config = OmegaConf.load(config_path) - model = instantiate_from_config(config.model) - model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) - return model - - if not model_name in PRETRAINED_MODELS: - raise RuntimeError( - f"Model name {model_name} is not a pre-trained model. Available models are:\n- " - + "\n- ".join(PRETRAINED_MODELS.keys()) - ) - model_info = PRETRAINED_MODELS[model_name] - - # Instiantiate the model - print(f"Loading model from config: {model_info['config']}") - config_file = get_config_file(model_info["config"]) - config = OmegaConf.load(config_file) - model = instantiate_from_config(config.model) - - # Load pre-trained checkpoint from huggingface - if not ckpt_path: - ckpt_path = hf_hub_download( - repo_id=model_info["repo_id"], - filename=model_info["filename"], - cache_dir=cache_dir, - ) - print(f"Loading model from cache file: {ckpt_path}") - model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) - return model diff --git a/apps/third_party/CRM/inference.py b/apps/third_party/CRM/inference.py deleted file mode 100644 index a6fc2a9e49d606dc44d0cb8ae0cfbf223b44dcd4..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/inference.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np -import torch -import time -import nvdiffrast.torch as dr -from util.utils import get_tri -import tempfile -from mesh import Mesh -import zipfile -def generate3d(model, rgb, ccm, device): - - color_tri = torch.from_numpy(rgb)/255 - xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255 - color = color_tri.permute(2,0,1) - xyz = xyz_tri.permute(2,0,1) - - - def get_imgs(color): - # color : [C, H, W*6] - color_list = [] - color_list.append(color[:,:,256*5:256*(1+5)]) - for i in range(0,5): - color_list.append(color[:,:,256*i:256*(1+i)]) - return torch.stack(color_list, dim=0)# [6, C, H, W] - - triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C] - - color = get_imgs(color) - xyz = get_imgs(xyz) - - color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0) - xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0) - - triplane = torch.cat([color,xyz],dim=1).to(device) - # 3D visualize - model.eval() - glctx = dr.RasterizeCudaContext() - - if model.denoising == True: - tnew = 20 - tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device) - noise_new = torch.randn_like(triplane) *0.5+0.5 - triplane = model.scheduler.add_noise(triplane, noise_new, tnew) - start_time = time.time() - with torch.no_grad(): - triplane_feature2 = model.unet2(triplane,tnew) - end_time = time.time() - elapsed_time = end_time - start_time - print(f"unet takes {elapsed_time}s") - else: - triplane_feature2 = model.unet2(triplane) - - - with torch.no_grad(): - data_config = { - 'resolution': [1024, 1024], - "triview_color": triplane_color.to(device), - } - - verts, faces = model.decode(data_config, triplane_feature2) - - data_config['verts'] = verts[0] - data_config['faces'] = faces - - - from kiui.mesh_utils import clean_mesh - verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=False, remesh_size=0.005) - data_config['verts'] = torch.from_numpy(verts).cuda().contiguous() - data_config['faces'] = torch.from_numpy(faces).cuda().contiguous() - - start_time = time.time() - with torch.no_grad(): - mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name - model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2) - - mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, front_dir="+z") - mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name - mesh.write(mesh_path_glb+".glb") - - # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb') - # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name - # mesh_obj2.export(mesh_path_obj2+".obj") - - with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip: - myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj') - myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png') - myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl') - - end_time = time.time() - elapsed_time = end_time - start_time - print(f"uv takes {elapsed_time}s") - return mesh_path_glb+".glb", mesh_path_obj+'.zip' diff --git a/apps/third_party/CRM/libs/__init__.py b/apps/third_party/CRM/libs/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 377a573aa67f7cc53af4c8b05123f278611521fc..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 1b54287998f0608201936c65b8676e9a2f8d2c9f..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc deleted file mode 100644 index b19825c00604c13eafa9ac4db0ce70a9bf04e1df..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc deleted file mode 100644 index fc8c696f4c72d09559bcbc1519c671680fd92a35..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc deleted file mode 100644 index 1b872ef355b5c9cac32066264ac6d3fe8e6f49d0..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc deleted file mode 100644 index 888166c9fae24f2fd70b9939ac5803ec24ab5895..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/libs/base_utils.py b/apps/third_party/CRM/libs/base_utils.py deleted file mode 100644 index c90a548286e80ded1f317126d8f560f67a85e0d8..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/libs/base_utils.py +++ /dev/null @@ -1,84 +0,0 @@ -import numpy as np -import cv2 -import torch -import numpy as np -from PIL import Image - - -def instantiate_from_config(config): - if not "target" in config: - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - import importlib - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def tensor_detail(t): - assert type(t) == torch.Tensor - print(f"shape: {t.shape} mean: {t.mean():.2f}, std: {t.std():.2f}, min: {t.min():.2f}, max: {t.max():.2f}") - - - -def drawRoundRec(draw, color, x, y, w, h, r): - drawObject = draw - - '''Rounds''' - drawObject.ellipse((x, y, x + r, y + r), fill=color) - drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color) - drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color) - drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color) - - '''rec.s''' - drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color) - drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color) - - -def do_resize_content(original_image: Image, scale_rate): - # resize image content wile retain the original image size - if scale_rate != 1: - # Calculate the new size after rescaling - new_size = tuple(int(dim * scale_rate) for dim in original_image.size) - # Resize the image while maintaining the aspect ratio - resized_image = original_image.resize(new_size) - # Create a new image with the original size and black background - padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) - paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) - padded_image.paste(resized_image, paste_position) - return padded_image - else: - return original_image - -def add_stroke(img, color=(255, 255, 255), stroke_radius=3): - # color in R, G, B format - if isinstance(img, Image.Image): - assert img.mode == "RGBA" - img = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2BGRA) - else: - assert img.shape[2] == 4 - gray = img[:,:, 3] - ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY) - contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE) - res = cv2.drawContours(img, contours,-1, tuple(color)[::-1] + (255,), stroke_radius) - return Image.fromarray(cv2.cvtColor(res,cv2.COLOR_BGRA2RGBA)) - -def make_blob(image_size=(512, 512), sigma=0.2): - """ - make 2D blob image with: - I(x, y)=1-\exp \left(-\frac{(x-H / 2)^2+(y-W / 2)^2}{2 \sigma^2 HS}\right) - """ - import numpy as np - H, W = image_size - x = np.arange(0, W, 1, float) - y = np.arange(0, H, 1, float) - x, y = np.meshgrid(x, y) - x0 = W // 2 - y0 = H // 2 - img = 1 - np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2 * H * W)) - return (img * 255).astype(np.uint8) \ No newline at end of file diff --git a/apps/third_party/CRM/libs/sample.py b/apps/third_party/CRM/libs/sample.py deleted file mode 100644 index 6a2e5cc453796c10165401bea47814113b6381ba..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/libs/sample.py +++ /dev/null @@ -1,384 +0,0 @@ -import numpy as np -import torch -from imagedream.camera_utils import get_camera_for_index -from imagedream.ldm.util import set_seed, add_random_background -# import os -# import sys -# proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -# sys.path.append(proj_dir) -from apps.third_party.CRM.libs.base_utils import do_resize_content -from imagedream.ldm.models.diffusion.ddim import DDIMSampler -from torchvision import transforms as T - - -class ImageDreamDiffusion: - def __init__( - self, - model, - device, - dtype, - mode, - num_frames, - camera_views, - ref_position, - random_background=False, - offset_noise=False, - resize_rate=1, - image_size=256, - seed=1234, - ) -> None: - assert mode in ["pixel", "local"] - size = image_size - self.seed = seed - batch_size = max(4, num_frames) - - neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." - uc = model.get_learned_conditioning([neg_texts]).to(device) - sampler = DDIMSampler(model) - - # pre-compute camera matrices - camera = [get_camera_for_index(i).squeeze() for i in camera_views] - camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero - camera = torch.stack(camera) - camera = camera.repeat(batch_size // num_frames, 1).to(device) - - self.image_transform = T.Compose( - [ - T.Resize((size, size)), - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - self.dtype = dtype - self.ref_position = ref_position - self.mode = mode - self.random_background = random_background - self.resize_rate = resize_rate - self.num_frames = num_frames - self.size = size - self.device = device - self.batch_size = batch_size - self.model = model - self.sampler = sampler - self.uc = uc - self.camera = camera - self.offset_noise = offset_noise - - @staticmethod - def i2i( - model, - image_size, - prompt, - uc, - sampler, - ip=None, - step=20, - scale=5.0, - batch_size=8, - ddim_eta=0.0, - dtype=torch.float32, - device="cuda", - camera=None, - num_frames=4, - pixel_control=False, - transform=None, - offset_noise=False, - ): - """ The function supports additional image prompt. - Args: - model (_type_): the image dream model - image_size (_type_): size of diffusion output (standard 256) - prompt (_type_): text prompt for the image (prompt in type str) - uc (_type_): unconditional vector (tensor in shape [1, 77, 1024]) - sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler - ip (Image, optional): the image prompt. Defaults to None. - step (int, optional): _description_. Defaults to 20. - scale (float, optional): _description_. Defaults to 7.5. - batch_size (int, optional): _description_. Defaults to 8. - ddim_eta (float, optional): _description_. Defaults to 0.0. - dtype (_type_, optional): _description_. Defaults to torch.float32. - device (str, optional): _description_. Defaults to "cuda". - camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 - num_frames (int, optional): _num of frames (views) to generate - pixel_control: whether to use pixel conditioning. Defaults to False, True when using pixel mode - transform: Compose( - Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn) - ToTensor() - Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ) - """ - ip_raw = ip - if type(prompt) != list: - prompt = [prompt] - with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): - c = model.get_learned_conditioning(prompt).to( - device - ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 - c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size - uc_ = {"context": uc.repeat(batch_size, 1, 1)} - - if camera is not None: - c_["camera"] = uc_["camera"] = ( - camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 - ) - c_["num_frames"] = uc_["num_frames"] = num_frames - - if ip is not None: - ip_embed = model.get_learned_image_conditioning(ip).to( - device - ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 - ip_ = ip_embed.repeat(batch_size, 1, 1) - c_["ip"] = ip_ - uc_["ip"] = torch.zeros_like(ip_) - - if pixel_control: - assert camera is not None - ip = transform(ip).to( - device - ) # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00 - ip_img = model.get_first_stage_encoding( - model.encode_first_stage(ip[None, :, :, :]) - ) # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55 - c_["ip_img"] = ip_img - uc_["ip_img"] = torch.zeros_like(ip_img) - - shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] - if offset_noise: - ref = transform(ip_raw).to(device) - ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) - ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) - time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) - x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) - - samples_ddim, _ = ( - sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43 - S=step, - conditioning=c_, - batch_size=batch_size, - shape=shape, - verbose=False, - unconditional_guidance_scale=scale, - unconditional_conditioning=uc_, - eta=ddim_eta, - x_T=x_T if offset_noise else None, - ) - ) - - x_sample = model.decode_first_stage(samples_ddim) - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() - - return list(x_sample.astype(np.uint8)) - - def diffuse(self, t, ip, n_test=2): - set_seed(self.seed) - ip = do_resize_content(ip, self.resize_rate) - if self.random_background: - ip = add_random_background(ip) - - images = [] - for _ in range(n_test): - img = self.i2i( - self.model, - self.size, - t, - self.uc, - self.sampler, - ip=ip, - step=50, - scale=5, - batch_size=self.batch_size, - ddim_eta=0.0, - dtype=self.dtype, - device=self.device, - camera=self.camera, - num_frames=self.num_frames, - pixel_control=(self.mode == "pixel"), - transform=self.image_transform, - offset_noise=self.offset_noise, - ) - img = np.concatenate(img, 1) - img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1) - images.append(img) - set_seed() # unset random and numpy seed - return images - - -class ImageDreamDiffusionStage2: - def __init__( - self, - model, - device, - dtype, - num_frames, - camera_views, - ref_position, - random_background=False, - offset_noise=False, - resize_rate=1, - mode="pixel", - image_size=256, - seed=1234, - ) -> None: - assert mode in ["pixel", "local"] - - size = image_size - self.seed = seed - batch_size = max(4, num_frames) - - neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." - uc = model.get_learned_conditioning([neg_texts]).to(device) - sampler = DDIMSampler(model) - - # pre-compute camera matrices - camera = [get_camera_for_index(i).squeeze() for i in camera_views] - if ref_position is not None: - camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero - camera = torch.stack(camera) - camera = camera.repeat(batch_size // num_frames, 1).to(device) - - self.image_transform = T.Compose( - [ - T.Resize((size, size)), - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - self.dtype = dtype - self.mode = mode - self.ref_position = ref_position - self.random_background = random_background - self.resize_rate = resize_rate - self.num_frames = num_frames - self.size = size - self.device = device - self.batch_size = batch_size - self.model = model - self.sampler = sampler - self.uc = uc - self.camera = camera - self.offset_noise = offset_noise - - @staticmethod - def i2iStage2( - model, - image_size, - prompt, - uc, - sampler, - pixel_images, - ip=None, - step=20, - scale=5.0, - batch_size=8, - ddim_eta=0.0, - dtype=torch.float32, - device="cuda", - camera=None, - num_frames=4, - pixel_control=False, - transform=None, - offset_noise=False, - ): - ip_raw = ip - if type(prompt) != list: - prompt = [prompt] - with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): - c = model.get_learned_conditioning(prompt).to( - device - ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 - c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size - uc_ = {"context": uc.repeat(batch_size, 1, 1)} - - if camera is not None: - c_["camera"] = uc_["camera"] = ( - camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 - ) - c_["num_frames"] = uc_["num_frames"] = num_frames - - if ip is not None: - ip_embed = model.get_learned_image_conditioning(ip).to( - device - ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 - ip_ = ip_embed.repeat(batch_size, 1, 1) - c_["ip"] = ip_ - uc_["ip"] = torch.zeros_like(ip_) - - if pixel_control: - assert camera is not None - - transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images]) - latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images)) - - c_["pixel_images"] = latent_pixel_images - uc_["pixel_images"] = torch.zeros_like(latent_pixel_images) - - shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] - if offset_noise: - ref = transform(ip_raw).to(device) - ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) - ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) - time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) - x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) - - samples_ddim, _ = ( - sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43 - S=step, - conditioning=c_, - batch_size=batch_size, - shape=shape, - verbose=False, - unconditional_guidance_scale=scale, - unconditional_conditioning=uc_, - eta=ddim_eta, - x_T=x_T if offset_noise else None, - ) - ) - x_sample = model.decode_first_stage(samples_ddim) - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() - - return list(x_sample.astype(np.uint8)) - - @torch.no_grad() - def diffuse(self, t, ip, pixel_images, n_test=2): - set_seed(self.seed) - ip = do_resize_content(ip, self.resize_rate) - pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images] - - if self.random_background: - bg_color = np.random.rand() * 255 - ip = add_random_background(ip, bg_color) - pixel_images = [add_random_background(i, bg_color) for i in pixel_images] - - images = [] - for _ in range(n_test): - img = self.i2iStage2( - self.model, - self.size, - t, - self.uc, - self.sampler, - pixel_images=pixel_images, - ip=ip, - step=50, - scale=5, - batch_size=self.batch_size, - ddim_eta=0.0, - dtype=self.dtype, - device=self.device, - camera=self.camera, - num_frames=self.num_frames, - pixel_control=(self.mode == "pixel"), - transform=self.image_transform, - offset_noise=self.offset_noise, - ) - img = np.concatenate(img, 1) - img = np.concatenate( - (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]), - axis=1, - ) - images.append(img) - set_seed() # unset random and numpy seed - return images diff --git a/apps/third_party/CRM/mesh.py b/apps/third_party/CRM/mesh.py deleted file mode 100644 index b98dea041fb41d207b6e95ed927216344854d25c..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/mesh.py +++ /dev/null @@ -1,845 +0,0 @@ -import os -import cv2 -import torch -import trimesh -import numpy as np - -from kiui.op import safe_normalize, dot -from kiui.typing import * - -class Mesh: - """ - A torch-native trimesh class, with support for ``ply/obj/glb`` formats. - - Note: - This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture). - """ - def __init__( - self, - v: Optional[Tensor] = None, - f: Optional[Tensor] = None, - vn: Optional[Tensor] = None, - fn: Optional[Tensor] = None, - vt: Optional[Tensor] = None, - ft: Optional[Tensor] = None, - vc: Optional[Tensor] = None, # vertex color - albedo: Optional[Tensor] = None, - metallicRoughness: Optional[Tensor] = None, - device: Optional[torch.device] = None, - ): - """Init a mesh directly using all attributes. - - Args: - v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None. - f (Optional[Tensor]): faces, int [M, 3]. Defaults to None. - vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None. - fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None. - vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None. - ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None. - vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None. - albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None. - metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None. - device (Optional[torch.device]): torch device. Defaults to None. - """ - self.device = device - self.v = v - self.vn = vn - self.vt = vt - self.f = f - self.fn = fn - self.ft = ft - # will first see if there is vertex color to use - self.vc = vc - # only support a single albedo image - self.albedo = albedo - # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1] - # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html - self.metallicRoughness = metallicRoughness - - self.ori_center = 0 - self.ori_scale = 1 - - @classmethod - def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs): - """load mesh from path. - - Args: - path (str): path to mesh file, supports ply, obj, glb. - clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False. - resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True. - renormal (bool, optional): re-calc the vertex normals. Defaults to True. - retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False. - bound (float, optional): bound to resize. Defaults to 0.9. - front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'. - device (torch.device, optional): torch device. Defaults to None. - - Note: - a ``device`` keyword argument can be provided to specify the torch device. - If it's not provided, we will try to use ``'cuda'`` as the device if it's available. - - Returns: - Mesh: the loaded Mesh object. - """ - # obj supports face uv - if path.endswith(".obj"): - mesh = cls.load_obj(path, **kwargs) - # trimesh only supports vertex uv, but can load more formats - else: - mesh = cls.load_trimesh(path, **kwargs) - - # clean - if clean: - from kiui.mesh_utils import clean_mesh - vertices = mesh.v.detach().cpu().numpy() - triangles = mesh.f.detach().cpu().numpy() - vertices, triangles = clean_mesh(vertices, triangles, remesh=False) - mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device) - mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device) - - print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}") - # auto-normalize - if resize: - mesh.auto_size(bound=bound) - # auto-fix normal - if renormal or mesh.vn is None: - mesh.auto_normal() - print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}") - # auto-fix texcoords - if retex or (mesh.albedo is not None and mesh.vt is None): - mesh.auto_uv(cache_path=path) - print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}") - - # rotate front dir to +z - if front_dir != "+z": - # axis switch - if "-z" in front_dir: - T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32) - elif "+x" in front_dir: - T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) - elif "-x" in front_dir: - T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) - elif "+y" in front_dir: - T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) - elif "-y" in front_dir: - T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) - else: - T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) - # rotation (how many 90 degrees) - if '1' in front_dir: - T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) - elif '2' in front_dir: - T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) - elif '3' in front_dir: - T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) - mesh.v @= T - mesh.vn @= T - - return mesh - - # load from obj file - @classmethod - def load_obj(cls, path, albedo_path=None, device=None): - """load an ``obj`` mesh. - - Args: - path (str): path to mesh. - albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None. - device (torch.device, optional): torch device. Defaults to None. - - Note: - We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension. - The `usemtl` statement is ignored, and we only use the last material path in `mtl` file. - - Returns: - Mesh: the loaded Mesh object. - """ - assert os.path.splitext(path)[-1] == ".obj" - - mesh = cls() - - # device - if device is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - mesh.device = device - - # load obj - with open(path, "r") as f: - lines = f.readlines() - - def parse_f_v(fv): - # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided) - # supported forms: - # f v1 v2 v3 - # f v1/vt1 v2/vt2 v3/vt3 - # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3 - # f v1//vn1 v2//vn2 v3//vn3 - xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")] - xs.extend([-1] * (3 - len(xs))) - return xs[0], xs[1], xs[2] - - vertices, texcoords, normals = [], [], [] - faces, tfaces, nfaces = [], [], [] - mtl_path = None - - for line in lines: - split_line = line.split() - # empty line - if len(split_line) == 0: - continue - prefix = split_line[0].lower() - # mtllib - if prefix == "mtllib": - mtl_path = split_line[1] - # usemtl - elif prefix == "usemtl": - pass # ignored - # v/vn/vt - elif prefix == "v": - vertices.append([float(v) for v in split_line[1:]]) - elif prefix == "vn": - normals.append([float(v) for v in split_line[1:]]) - elif prefix == "vt": - val = [float(v) for v in split_line[1:]] - texcoords.append([val[0], 1.0 - val[1]]) - elif prefix == "f": - vs = split_line[1:] - nv = len(vs) - v0, t0, n0 = parse_f_v(vs[0]) - for i in range(nv - 2): # triangulate (assume vertices are ordered) - v1, t1, n1 = parse_f_v(vs[i + 1]) - v2, t2, n2 = parse_f_v(vs[i + 2]) - faces.append([v0, v1, v2]) - tfaces.append([t0, t1, t2]) - nfaces.append([n0, n1, n2]) - - mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) - mesh.vt = ( - torch.tensor(texcoords, dtype=torch.float32, device=device) - if len(texcoords) > 0 - else None - ) - mesh.vn = ( - torch.tensor(normals, dtype=torch.float32, device=device) - if len(normals) > 0 - else None - ) - - mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) - mesh.ft = ( - torch.tensor(tfaces, dtype=torch.int32, device=device) - if len(texcoords) > 0 - else None - ) - mesh.fn = ( - torch.tensor(nfaces, dtype=torch.int32, device=device) - if len(normals) > 0 - else None - ) - - # see if there is vertex color - use_vertex_color = False - if mesh.v.shape[1] == 6: - use_vertex_color = True - mesh.vc = mesh.v[:, 3:] - mesh.v = mesh.v[:, :3] - print(f"[load_obj] use vertex color: {mesh.vc.shape}") - - # try to load texture image - if not use_vertex_color: - # try to retrieve mtl file - mtl_path_candidates = [] - if mtl_path is not None: - mtl_path_candidates.append(mtl_path) - mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path)) - mtl_path_candidates.append(path.replace(".obj", ".mtl")) - - mtl_path = None - for candidate in mtl_path_candidates: - if os.path.exists(candidate): - mtl_path = candidate - break - - # if albedo_path is not provided, try retrieve it from mtl - metallic_path = None - roughness_path = None - if mtl_path is not None and albedo_path is None: - with open(mtl_path, "r") as f: - lines = f.readlines() - - for line in lines: - split_line = line.split() - # empty line - if len(split_line) == 0: - continue - prefix = split_line[0] - - if "map_Kd" in prefix: - # assume relative path! - albedo_path = os.path.join(os.path.dirname(path), split_line[1]) - print(f"[load_obj] use texture from: {albedo_path}") - elif "map_Pm" in prefix: - metallic_path = os.path.join(os.path.dirname(path), split_line[1]) - elif "map_Pr" in prefix: - roughness_path = os.path.join(os.path.dirname(path), split_line[1]) - - # still not found albedo_path, or the path doesn't exist - if albedo_path is None or not os.path.exists(albedo_path): - # init an empty texture - print(f"[load_obj] init empty albedo!") - # albedo = np.random.rand(1024, 1024, 3).astype(np.float32) - albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color - else: - albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) - albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) - albedo = albedo.astype(np.float32) / 255 - print(f"[load_obj] load texture: {albedo.shape}") - - mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device) - - # try to load metallic and roughness - if metallic_path is not None and roughness_path is not None: - print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}") - metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED) - metallic = metallic.astype(np.float32) / 255 - roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED) - roughness = roughness.astype(np.float32) / 255 - metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1) - - mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() - - return mesh - - @classmethod - def load_trimesh(cls, path, device=None): - """load a mesh using ``trimesh.load()``. - - Can load various formats like ``glb`` and serves as a fallback. - - Note: - We will try to merge all meshes if the glb contains more than one, - but **this may cause the texture to lose**, since we only support one texture image! - - Args: - path (str): path to the mesh file. - device (torch.device, optional): torch device. Defaults to None. - - Returns: - Mesh: the loaded Mesh object. - """ - mesh = cls() - - # device - if device is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - mesh.device = device - - # use trimesh to load ply/glb - _data = trimesh.load(path) - if isinstance(_data, trimesh.Scene): - if len(_data.geometry) == 1: - _mesh = list(_data.geometry.values())[0] - else: - print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.") - _concat = [] - # loop the scene graph and apply transform to each mesh - scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}} - for k, v in scene_graph.items(): - name = v['geometry'] - if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh): - transform = v['transform'] - _concat.append(_data.geometry[name].apply_transform(transform)) - _mesh = trimesh.util.concatenate(_concat) - else: - _mesh = _data - - if _mesh.visual.kind == 'vertex': - vertex_colors = _mesh.visual.vertex_colors - vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255 - mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device) - print(f"[load_trimesh] use vertex color: {mesh.vc.shape}") - elif _mesh.visual.kind == 'texture': - _material = _mesh.visual.material - if isinstance(_material, trimesh.visual.material.PBRMaterial): - texture = np.array(_material.baseColorTexture).astype(np.float32) / 255 - # load metallicRoughness if present - if _material.metallicRoughnessTexture is not None: - metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255 - mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() - elif isinstance(_material, trimesh.visual.material.SimpleMaterial): - texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255 - else: - raise NotImplementedError(f"material type {type(_material)} not supported!") - mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous() - print(f"[load_trimesh] load texture: {texture.shape}") - else: - texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) - mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device) - print(f"[load_trimesh] failed to load texture.") - - vertices = _mesh.vertices - - try: - texcoords = _mesh.visual.uv - texcoords[:, 1] = 1 - texcoords[:, 1] - except Exception as e: - texcoords = None - - try: - normals = _mesh.vertex_normals - except Exception as e: - normals = None - - # trimesh only support vertex uv... - faces = tfaces = nfaces = _mesh.faces - - mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) - mesh.vt = ( - torch.tensor(texcoords, dtype=torch.float32, device=device) - if texcoords is not None - else None - ) - mesh.vn = ( - torch.tensor(normals, dtype=torch.float32, device=device) - if normals is not None - else None - ) - - mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) - mesh.ft = ( - torch.tensor(tfaces, dtype=torch.int32, device=device) - if texcoords is not None - else None - ) - mesh.fn = ( - torch.tensor(nfaces, dtype=torch.int32, device=device) - if normals is not None - else None - ) - - return mesh - - # sample surface (using trimesh) - def sample_surface(self, count: int): - """sample points on the surface of the mesh. - - Args: - count (int): number of points to sample. - - Returns: - torch.Tensor: the sampled points, float [count, 3]. - """ - _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy()) - points, face_idx = trimesh.sample.sample_surface(_mesh, count) - points = torch.from_numpy(points).float().to(self.device) - return points - - # aabb - def aabb(self): - """get the axis-aligned bounding box of the mesh. - - Returns: - Tuple[torch.Tensor]: the min xyz and max xyz of the mesh. - """ - return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values - - # unit size - @torch.no_grad() - def auto_size(self, bound=0.9): - """auto resize the mesh. - - Args: - bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9. - """ - vmin, vmax = self.aabb() - self.ori_center = (vmax + vmin) / 2 - self.ori_scale = 2 * bound / torch.max(vmax - vmin).item() - self.v = (self.v - self.ori_center) * self.ori_scale - - def auto_normal(self): - """auto calculate the vertex normals. - """ - i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long() - v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :] - - face_normals = torch.cross(v1 - v0, v2 - v0) - - # Splat face normals to vertices - vn = torch.zeros_like(self.v) - vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) - vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) - vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) - - # Normalize, replace zero (degenerated) normals with some default value - vn = torch.where( - dot(vn, vn) > 1e-20, - vn, - torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), - ) - vn = safe_normalize(vn) - - self.vn = vn - self.fn = self.f - - def auto_uv(self, cache_path=None, vmap=True): - """auto calculate the uv coordinates. - - Args: - cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None. - vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf). - Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True. - """ - # try to load cache - if cache_path is not None: - cache_path = os.path.splitext(cache_path)[0] + "_uv.npz" - if cache_path is not None and os.path.exists(cache_path): - data = np.load(cache_path) - vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"] - else: - import xatlas - - v_np = self.v.detach().cpu().numpy() - f_np = self.f.detach().int().cpu().numpy() - atlas = xatlas.Atlas() - atlas.add_mesh(v_np, f_np) - chart_options = xatlas.ChartOptions() - # chart_options.max_iterations = 4 - atlas.generate(chart_options=chart_options) - vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] - - # save to cache - if cache_path is not None: - np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping) - - vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device) - ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device) - self.vt = vt - self.ft = ft - - if vmap: - vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device) - self.align_v_to_vt(vmapping) - - def align_v_to_vt(self, vmapping=None): - """ remap v/f and vn/fn to vt/ft. - - Args: - vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None. - """ - if vmapping is None: - ft = self.ft.view(-1).long() - f = self.f.view(-1).long() - vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device) - vmapping[ft] = f # scatter, randomly choose one if index is not unique - - self.v = self.v[vmapping] - self.f = self.ft - - if self.vn is not None: - self.vn = self.vn[vmapping] - self.fn = self.ft - - def to(self, device): - """move all tensor attributes to device. - - Args: - device (torch.device): target device. - - Returns: - Mesh: self. - """ - self.device = device - for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]: - tensor = getattr(self, name) - if tensor is not None: - setattr(self, name, tensor.to(device)) - return self - - def write(self, path): - """write the mesh to a path. - - Args: - path (str): path to write, supports ply, obj and glb. - """ - if path.endswith(".ply"): - self.write_ply(path) - elif path.endswith(".obj"): - self.write_obj(path) - elif path.endswith(".glb") or path.endswith(".gltf"): - self.write_glb(path) - else: - raise NotImplementedError(f"format {path} not supported!") - - def write_ply(self, path): - """write the mesh in ply format. Only for geometry! - - Args: - path (str): path to write. - """ - - if self.albedo is not None: - print(f'[WARN] ply format does not support exporting texture, will ignore!') - - v_np = self.v.detach().cpu().numpy() - f_np = self.f.detach().cpu().numpy() - - _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np) - _mesh.export(path) - - - def write_glb(self, path): - """write the mesh in glb/gltf format. - This will create a scene with a single mesh. - - Args: - path (str): path to write. - """ - - # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0] - if self.vt is not None and self.v.shape[0] != self.vt.shape[0]: - self.align_v_to_vt() - - import pygltflib - - f_np = self.f.detach().cpu().numpy().astype(np.uint32) - f_np_blob = f_np.flatten().tobytes() - - v_np = self.v.detach().cpu().numpy().astype(np.float32) - v_np_blob = v_np.tobytes() - - blob = f_np_blob + v_np_blob - byteOffset = len(blob) - - # base mesh - gltf = pygltflib.GLTF2( - scene=0, - scenes=[pygltflib.Scene(nodes=[0])], - nodes=[pygltflib.Node(mesh=0)], - meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive( - # indices to accessors (0 is triangles) - attributes=pygltflib.Attributes( - POSITION=1, - ), - indices=0, - )])], - buffers=[ - pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob)) - ], - # buffer view (based on dtype) - bufferViews=[ - # triangles; as flatten (element) array - pygltflib.BufferView( - buffer=0, - byteLength=len(f_np_blob), - target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963) - ), - # positions; as vec3 array - pygltflib.BufferView( - buffer=0, - byteOffset=len(f_np_blob), - byteLength=len(v_np_blob), - byteStride=12, # vec3 - target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962) - ), - ], - accessors=[ - # 0 = triangles - pygltflib.Accessor( - bufferView=0, - componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125) - count=f_np.size, - type=pygltflib.SCALAR, - max=[int(f_np.max())], - min=[int(f_np.min())], - ), - # 1 = positions - pygltflib.Accessor( - bufferView=1, - componentType=pygltflib.FLOAT, # GL_FLOAT (5126) - count=len(v_np), - type=pygltflib.VEC3, - max=v_np.max(axis=0).tolist(), - min=v_np.min(axis=0).tolist(), - ), - ], - ) - - # append texture info - if self.vt is not None: - - vt_np = self.vt.detach().cpu().numpy().astype(np.float32) - vt_np_blob = vt_np.tobytes() - - albedo = self.albedo.detach().cpu().numpy() - albedo = (albedo * 255).astype(np.uint8) - albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR) - albedo_blob = cv2.imencode('.png', albedo)[1].tobytes() - - # update primitive - gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2 - gltf.meshes[0].primitives[0].material = 0 - - # update materials - gltf.materials.append(pygltflib.Material( - pbrMetallicRoughness=pygltflib.PbrMetallicRoughness( - baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0), - metallicFactor=0.0, - roughnessFactor=1.0, - ), - alphaMode=pygltflib.OPAQUE, - alphaCutoff=None, - doubleSided=True, - )) - - gltf.textures.append(pygltflib.Texture(sampler=0, source=0)) - gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) - gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png")) - - # update buffers - gltf.bufferViews.append( - # index = 2, texcoords; as vec2 array - pygltflib.BufferView( - buffer=0, - byteOffset=byteOffset, - byteLength=len(vt_np_blob), - byteStride=8, # vec2 - target=pygltflib.ARRAY_BUFFER, - ) - ) - - gltf.accessors.append( - # 2 = texcoords - pygltflib.Accessor( - bufferView=2, - componentType=pygltflib.FLOAT, - count=len(vt_np), - type=pygltflib.VEC2, - max=vt_np.max(axis=0).tolist(), - min=vt_np.min(axis=0).tolist(), - ) - ) - - blob += vt_np_blob - byteOffset += len(vt_np_blob) - - gltf.bufferViews.append( - # index = 3, albedo texture; as none target - pygltflib.BufferView( - buffer=0, - byteOffset=byteOffset, - byteLength=len(albedo_blob), - ) - ) - - blob += albedo_blob - byteOffset += len(albedo_blob) - - gltf.buffers[0].byteLength = byteOffset - - # append metllic roughness - if self.metallicRoughness is not None: - metallicRoughness = self.metallicRoughness.detach().cpu().numpy() - metallicRoughness = (metallicRoughness * 255).astype(np.uint8) - metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR) - metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes() - - # update texture definition - gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0 - gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0 - gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0) - - gltf.textures.append(pygltflib.Texture(sampler=1, source=1)) - gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) - gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png")) - - # update buffers - gltf.bufferViews.append( - # index = 4, metallicRoughness texture; as none target - pygltflib.BufferView( - buffer=0, - byteOffset=byteOffset, - byteLength=len(metallicRoughness_blob), - ) - ) - - blob += metallicRoughness_blob - byteOffset += len(metallicRoughness_blob) - - gltf.buffers[0].byteLength = byteOffset - - - # set actual data - gltf.set_binary_blob(blob) - - # glb = b"".join(gltf.save_to_bytes()) - gltf.save(path) - - - def write_obj(self, path): - """write the mesh in obj format. Will also write the texture and mtl files. - - Args: - path (str): path to write. - """ - - mtl_path = path.replace(".obj", ".mtl") - albedo_path = path.replace(".obj", "_albedo.png") - metallic_path = path.replace(".obj", "_metallic.png") - roughness_path = path.replace(".obj", "_roughness.png") - - v_np = self.v.detach().cpu().numpy() - vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None - vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None - f_np = self.f.detach().cpu().numpy() - ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None - fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None - - with open(path, "w") as fp: - fp.write(f"mtllib {os.path.basename(mtl_path)} \n") - - for v in v_np: - fp.write(f"v {v[0]} {v[1]} {v[2]} \n") - - if vt_np is not None: - for v in vt_np: - fp.write(f"vt {v[0]} {1 - v[1]} \n") - - if vn_np is not None: - for v in vn_np: - fp.write(f"vn {v[0]} {v[1]} {v[2]} \n") - - fp.write(f"usemtl defaultMat \n") - for i in range(len(f_np)): - fp.write( - f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \ - {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \ - {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n' - ) - - with open(mtl_path, "w") as fp: - fp.write(f"newmtl defaultMat \n") - fp.write(f"Ka 1 1 1 \n") - fp.write(f"Kd 1 1 1 \n") - fp.write(f"Ks 0 0 0 \n") - fp.write(f"Tr 1 \n") - fp.write(f"illum 1 \n") - fp.write(f"Ns 0 \n") - if self.albedo is not None: - fp.write(f"map_Kd {os.path.basename(albedo_path)} \n") - if self.metallicRoughness is not None: - # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering - fp.write(f"map_Pm {os.path.basename(metallic_path)} \n") - fp.write(f"map_Pr {os.path.basename(roughness_path)} \n") - - if self.albedo is not None: - albedo = self.albedo.detach().cpu().numpy() - albedo = (albedo * 255).astype(np.uint8) - cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)) - - if self.metallicRoughness is not None: - metallicRoughness = self.metallicRoughness.detach().cpu().numpy() - metallicRoughness = (metallicRoughness * 255).astype(np.uint8) - cv2.imwrite(metallic_path, metallicRoughness[..., 2]) - cv2.imwrite(roughness_path, metallicRoughness[..., 1]) - diff --git a/apps/third_party/CRM/model/.DS_Store b/apps/third_party/CRM/model/.DS_Store deleted file mode 100644 index e2fd35f3e9054910a42dde149c88e430130d66d7..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/.DS_Store and /dev/null differ diff --git a/apps/third_party/CRM/model/__init__.py b/apps/third_party/CRM/model/__init__.py deleted file mode 100644 index b339e3ea9dac5482a6daed63b069e4d1eda000a8..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from model.crm.model import CRM \ No newline at end of file diff --git a/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index de5f1cf642bb8b776bf8da5e5d582f6001996da8..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/model/archs/__init__.py b/apps/third_party/CRM/model/archs/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 1b1f2b99930a199372b91624098bf60f9baa8469..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc deleted file mode 100644 index da7fbd72afdb4465902b7f80d4dda57c8f3b9ee5..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc deleted file mode 100644 index 932b63f9d151ae809cb092048dd560ac436567f7..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/model/archs/decoders/__init__.py b/apps/third_party/CRM/model/archs/decoders/__init__.py deleted file mode 100644 index d3f5a12faa99758192ecc4ed3fc22c9249232e86..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/model/archs/decoders/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index da3e1e7536c21056bb8268bd03799770646488ae..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc b/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc deleted file mode 100644 index 2fab55d21095a1a7e4841140ca2efb1c71f9730c..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py b/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py deleted file mode 100644 index 5e5ddd78215b9f48b281757a91b8de6f73e4742a..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class TetTexNet(nn.Module): - def __init__(self, plane_reso=64, padding=0.1, fea_concat=True): - super().__init__() - # self.c_dim = c_dim - self.plane_reso = plane_reso - self.padding = padding - self.fea_concat = fea_concat - - def forward(self, rolled_out_feature, query): - # rolled_out_feature: rolled-out triplane feature - # query: queried xyz coordinates (should be scaled consistently to ptr cloud) - - plane_reso = self.plane_reso - - triplane_feature = dict() - triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso] - triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso] - triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:] - - query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy') - query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz') - query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx') - - if self.fea_concat: - query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1) - else: - query_feature = query_feature_xy + query_feature_yz + query_feature_zx - - output = query_feature.permute(0, 2, 1) - - return output - - # uses values from plane_feature and pixel locations from vgrid to interpolate feature - def sample_plane_feature(self, query, plane_feature, plane): - # CYF note: - # for pretraining, query are uniformly sampled positions w.i. [-scale, scale] - # for training, query are essentially tetrahedra grid vertices, which are - # also within [-scale, scale] in the current version! - # xy range [-scale, scale] - if plane == 'xy': - xy = query[:, :, [0, 1]] - elif plane == 'yz': - xy = query[:, :, [1, 2]] - elif plane == 'zx': - xy = query[:, :, [2, 0]] - else: - raise ValueError("Error! Invalid plane type!") - - xy = xy[:, :, None].float() - # not seem necessary to rescale the grid, because from - # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html, - # it specifies sampling locations normalized by plane_feature's spatial dimension, - # which is within [-scale, scale] as specified by encoder's calling of coordinate2index() - vgrid = 1.0 * xy - sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1) - - return sampled_feat diff --git a/apps/third_party/CRM/model/archs/mlp_head.py b/apps/third_party/CRM/model/archs/mlp_head.py deleted file mode 100644 index 33d7dcdbf58374dd036d9f3f5f0bfd3f248e845b..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/model/archs/mlp_head.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class SdfMlp(nn.Module): - def __init__(self, input_dim, hidden_dim=512, bias=True): - super().__init__() - self.input_dim = input_dim - self.hidden_dim = hidden_dim - - self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) - self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) - self.fc3 = nn.Linear(hidden_dim, 4, bias=bias) - - - def forward(self, input): - x = F.relu(self.fc1(input)) - x = F.relu(self.fc2(x)) - out = self.fc3(x) - return out - - -class RgbMlp(nn.Module): - def __init__(self, input_dim, hidden_dim=512, bias=True): - super().__init__() - self.input_dim = input_dim - self.hidden_dim = hidden_dim - - self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) - self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) - self.fc3 = nn.Linear(hidden_dim, 3, bias=bias) - - def forward(self, input): - x = F.relu(self.fc1(input)) - x = F.relu(self.fc2(x)) - out = self.fc3(x) - - return out - - \ No newline at end of file diff --git a/apps/third_party/CRM/model/archs/unet.py b/apps/third_party/CRM/model/archs/unet.py deleted file mode 100644 index e427c18b1bed00089e87fe25b0e810042538d6b3..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/model/archs/unet.py +++ /dev/null @@ -1,53 +0,0 @@ -''' -Codes are from: -https://github.com/jaxony/unet-pytorch/blob/master/model.py -''' - -import torch -import torch.nn as nn -from diffusers import UNet2DModel -import einops -class UNetPP(nn.Module): - ''' - Wrapper for UNet in diffusers - ''' - def __init__(self, in_channels): - super(UNetPP, self).__init__() - self.in_channels = in_channels - self.unet = UNet2DModel( - sample_size=[256, 256*3], - in_channels=in_channels, - out_channels=32, - layers_per_block=2, - block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4), - down_block_types=( - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "AttnDownBlock2D", - "AttnDownBlock2D", - "AttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=( - "UpBlock2D", - "AttnUpBlock2D", - "AttnUpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - ), - ) - - self.unet.enable_xformers_memory_efficient_attention() - if in_channels > 12: - self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3])) - - def forward(self, x, t=256): - learned_plane = self.learned_plane - if x.shape[1] < self.in_channels: - learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device) - x = torch.cat([x, learned_plane], dim = 1) - return self.unet(x, t).sample - diff --git a/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc b/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc deleted file mode 100644 index 93b38087a8e1a71bab18af950898769b765939ab..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/model/crm/model.py b/apps/third_party/CRM/model/crm/model.py deleted file mode 100644 index 9eb164266a0c23295ab3451d99d720ec12afa468..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/model/crm/model.py +++ /dev/null @@ -1,213 +0,0 @@ -import torch.nn as nn -import torch -import torch.nn.functional as F - -import numpy as np - - -from pathlib import Path -import cv2 -import trimesh -import nvdiffrast.torch as dr - -from model.archs.decoders.shape_texture_net import TetTexNet -from model.archs.unet import UNetPP -from util.renderer import Renderer -from model.archs.mlp_head import SdfMlp, RgbMlp -import xatlas - - -class Dummy: - pass - -class CRM(nn.Module): - def __init__(self, specs): - super(CRM, self).__init__() - - self.specs = specs - # configs - input_specs = specs["Input"] - self.input = Dummy() - self.input.scale = input_specs['scale'] - self.input.resolution = input_specs['resolution'] - self.tet_grid_size = input_specs['tet_grid_size'] - self.camera_angle_num = input_specs['camera_angle_num'] - - self.arch = Dummy() - self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"] - self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"] - - self.dec = Dummy() - self.dec.c_dim = specs["DecoderSpecs"]["c_dim"] - self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"] - - self.geo_type = specs["Train"].get("geo_type", "flex") # "dmtet" or "flex" - - self.unet2 = UNetPP(in_channels=self.dec.c_dim) - - mlp_chnl_s = 3 if self.arch.fea_concat else 1 # 3 for queried triplane feature concatenation - self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat) - - if self.geo_type == "flex": - self.weightMlp = nn.Sequential( - nn.Linear(mlp_chnl_s * 32 * 8, 512), - nn.SiLU(), - nn.Linear(512, 21)) - - self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) - self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) - self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num, - scale=self.input.scale, geo_type = self.geo_type) - - - self.spob = True if specs['Pretrain']['mode'] is None else False # whether to add sphere - self.radius = specs['Pretrain']['radius'] # used when spob - - self.denoising = True - from diffusers import DDIMScheduler - self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler") - - def decode(self, data, triplane_feature2): - if self.geo_type == "flex": - tet_verts = self.renderer.flexicubes.verts.unsqueeze(0) - tet_indices = self.renderer.flexicubes.indices - - dec_verts = self.decoder(triplane_feature2, tet_verts) - out = self.sdfMlp(dec_verts) - - weight = None - if self.geo_type == "flex": - grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1),dim=1) - grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1]) - weight = self.weightMlp(grid_feat) - weight = weight * 0.1 - - pred_sdf, deformation = out[..., 0], out[..., 1:] - if self.spob: - pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1)) - - _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight= weight) - return verts[0].unsqueeze(0), faces[0].int() - - def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2 = None): - verts = data['verts'] - faces = data['faces'] - - dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0)) - colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy() - # Expect predicted colors value range from [-1, 1] - colors = (colors * 0.5 + 0.5).clip(0, 1) - - verts = verts.squeeze().cpu().numpy() - faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy() - - # export the final mesh - with torch.no_grad(): - mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault... - mesh.export(out_dir / f'{ind}.obj') - - def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None): - - mesh_v = data['verts'].squeeze().cpu().numpy() - mesh_pos_idx = data['faces'].squeeze().cpu().numpy() - - def interpolate(attr, rast, attr_idx, rast_db=None): - return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, - diff_attrs=None if rast_db is None else 'all') - - vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx) - - mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device) - mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device) - - # Convert to tensors - indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) - - uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) - mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) - # mesh_v_tex. ture - uv_clip = uvs[None, ...] * 2.0 - 1.0 - - # pad to four component coordinate - uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) - - # rasterize - rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res) - - # Interpolate world space position - gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) - mask = rast[..., 3:4] > 0 - - # return uvs, mesh_tex_idx, gb_pos, mask - gb_pos_unsqz = gb_pos.view(-1, 3) - mask_unsqz = mask.view(-1) - tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1 - - gb_mask_pos = gb_pos_unsqz[mask_unsqz] - - gb_mask_pos = gb_mask_pos[None, ] - - with torch.no_grad(): - - dec_verts = self.decoder(tri_fea_2, gb_mask_pos) - colors = self.rgbMlp(dec_verts).squeeze() - - # Expect predicted colors value range from [-1, 1] - lo, hi = (-1, 1) - colors = (colors - lo) * (255 / (hi - lo)) - colors = colors.clip(0, 255) - - tex_unsqz[mask_unsqz] = colors - - tex = tex_unsqz.view(res + (3,)) - - verts = mesh_v.squeeze().cpu().numpy() - faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy() - # faces = mesh_pos_idx - # faces = faces.detach().cpu().numpy() - # faces = faces[..., [2, 1, 0]] - indices = indices[..., [2, 1, 0]] - - # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs) - matname = f'{out_dir}.mtl' - # matname = f'{out_dir}/{ind}.mtl' - fid = open(matname, 'w') - fid.write('newmtl material_0\n') - fid.write('Kd 1 1 1\n') - fid.write('Ka 1 1 1\n') - # fid.write('Ks 0 0 0\n') - fid.write('Ks 0.4 0.4 0.4\n') - fid.write('Ns 10\n') - fid.write('illum 2\n') - fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n') - fid.close() - - fid = open(f'{out_dir}.obj', 'w') - # fid = open(f'{out_dir}/{ind}.obj', 'w') - fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1]) - - for pidx, p in enumerate(verts): - pp = p - fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1])) - - for pidx, p in enumerate(uvs): - pp = p - fid.write('vt %f %f\n' % (pp[0], 1 - pp[1])) - - fid.write('usemtl material_0\n') - for i, f in enumerate(faces): - f1 = f + 1 - f2 = indices[i] + 1 - fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) - fid.close() - - img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32) - mask = np.sum(img.astype(float), axis=-1, keepdims=True) - mask = (mask <= 3.0).astype(float) - kernel = np.ones((3, 3), 'uint8') - dilate_img = cv2.dilate(img, kernel, iterations=1) - img = img * (1 - mask) + dilate_img * mask - img = img.clip(0, 255).astype(np.uint8) - - cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]]) - # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]]) diff --git a/apps/third_party/CRM/pipelines.py b/apps/third_party/CRM/pipelines.py deleted file mode 100644 index 0ef19dc84dd7789197d02a29239d99b0a82558b1..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/pipelines.py +++ /dev/null @@ -1,205 +0,0 @@ -import torch -import os -import sys -proj_dir = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(proj_dir) -from .libs.base_utils import do_resize_content -from .imagedream.ldm.util import ( - instantiate_from_config, - get_obj_from_str, -) -from omegaconf import OmegaConf -from PIL import Image -import PIL -import rembg -class TwoStagePipeline(object): - def __init__( - self, - stage1_model_config, - stage1_sampler_config, - device="cuda", - dtype=torch.float16, - resize_rate=1, - ) -> None: - """ - only for two stage generate process. - - the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config - - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config - """ - self.resize_rate = resize_rate - - self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model) - self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False) - self.stage1_model = self.stage1_model.to(device).to(dtype) - - self.stage1_model.device = device - self.device = device - self.dtype = dtype - self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)( - self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params - ) - - def stage1_sample( - self, - pixel_img, - prompt="3D assets", - neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.", - step=50, - scale=5, - ddim_eta=0.0, - ): - if type(pixel_img) == str: - pixel_img = Image.open(pixel_img) - - if isinstance(pixel_img, Image.Image): - if pixel_img.mode == "RGBA": - background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0)) - pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB") - else: - pixel_img = pixel_img.convert("RGB") - else: - raise - uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device) - stage1_images = self.stage1_sampler.i2i( - self.stage1_sampler.model, - self.stage1_sampler.size, - prompt, - uc=uc, - sampler=self.stage1_sampler.sampler, - ip=pixel_img, - step=step, - scale=scale, - batch_size=self.stage1_sampler.batch_size, - ddim_eta=ddim_eta, - dtype=self.stage1_sampler.dtype, - device=self.stage1_sampler.device, - camera=self.stage1_sampler.camera, - num_frames=self.stage1_sampler.num_frames, - pixel_control=(self.stage1_sampler.mode == "pixel"), - transform=self.stage1_sampler.image_transform, - offset_noise=self.stage1_sampler.offset_noise, - ) - - stage1_images = [Image.fromarray(img) for img in stage1_images] - stage1_images.pop(self.stage1_sampler.ref_position) - return stage1_images - - def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50): - if type(pixel_img) == str: - pixel_img = Image.open(pixel_img) - - if isinstance(pixel_img, Image.Image): - if pixel_img.mode == "RGBA": - background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0)) - pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB") - else: - pixel_img = pixel_img.convert("RGB") - else: - raise - stage2_images = self.stage2_sampler.i2iStage2( - self.stage2_sampler.model, - self.stage2_sampler.size, - "3D assets", - self.stage2_sampler.uc, - self.stage2_sampler.sampler, - pixel_images=stage1_images, - ip=pixel_img, - step=step, - scale=scale, - batch_size=self.stage2_sampler.batch_size, - ddim_eta=0.0, - dtype=self.stage2_sampler.dtype, - device=self.stage2_sampler.device, - camera=self.stage2_sampler.camera, - num_frames=self.stage2_sampler.num_frames, - pixel_control=(self.stage2_sampler.mode == "pixel"), - transform=self.stage2_sampler.image_transform, - offset_noise=self.stage2_sampler.offset_noise, - ) - stage2_images = [Image.fromarray(img) for img in stage2_images] - return stage2_images - - def set_seed(self, seed): - self.stage1_sampler.seed = seed - # self.stage2_sampler.seed = seed - - def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50): - pixel_img = do_resize_content(pixel_img, self.resize_rate) - stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step) - # stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step) - - return { - "ref_img": pixel_img, - "stage1_images": stage1_images, - # "stage2_images": stage2_images, - } - -rembg_session = rembg.new_session() - -def expand_to_square(image, bg_color=(0, 0, 0, 0)): - # expand image to 1:1 - width, height = image.size - if width == height: - return image - new_size = (max(width, height), max(width, height)) - new_image = Image.new("RGBA", new_size, bg_color) - paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) - new_image.paste(image, paste_position) - return new_image - -def remove_background( - image: PIL.Image.Image, - rembg_session = None, - force: bool = False, - **rembg_kwargs, -) -> PIL.Image.Image: - do_remove = True - if image.mode == "RGBA" and image.getextrema()[3][0] < 255: - # explain why current do not rm bg - print("alhpa channl not enpty, skip remove background, using alpha channel as mask") - background = Image.new("RGBA", image.size, (0, 0, 0, 0)) - image = Image.alpha_composite(background, image) - do_remove = False - do_remove = do_remove or force - if do_remove: - image = rembg.remove(image, session=rembg_session, **rembg_kwargs) - return image - -def do_resize_content(original_image: Image, scale_rate): - # resize image content wile retain the original image size - if scale_rate != 1: - # Calculate the new size after rescaling - new_size = tuple(int(dim * scale_rate) for dim in original_image.size) - # Resize the image while maintaining the aspect ratio - resized_image = original_image.resize(new_size) - # Create a new image with the original size and black background - padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) - paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) - padded_image.paste(resized_image, paste_position) - return padded_image - else: - return original_image - -def add_background(image, bg_color=(255, 255, 255)): - # given an RGBA image, alpha channel is used as mask to add background color - background = Image.new("RGBA", image.size, bg_color) - return Image.alpha_composite(background, image) - - -def preprocess_image(image, background_choice, foreground_ratio, backgroud_color): - """ - input image is a pil image in RGBA, return RGB image - """ - print(background_choice) - if background_choice == "Alpha as mask": - background = Image.new("RGBA", image.size, (0, 0, 0, 0)) - image = Image.alpha_composite(background, image) - else: - image = remove_background(image, rembg_session, force_remove=True) - image = do_resize_content(image, foreground_ratio) - image = expand_to_square(image) - image = add_background(image, backgroud_color) - return image.convert("RGB") - - - diff --git a/apps/third_party/CRM/requirements.txt b/apps/third_party/CRM/requirements.txt deleted file mode 100644 index 8501f40d7ec31f64b3b6b77549fb2b7623f2d382..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/requirements.txt +++ /dev/null @@ -1,16 +0,0 @@ -gradio -huggingface-hub -diffusers==0.24.0 -einops==0.7.0 -Pillow==10.1.0 -transformers==4.27.1 -open-clip-torch==2.7.0 -opencv-contrib-python-headless==4.9.0.80 -opencv-python-headless==4.9.0.80 -omegaconf -rembg -pygltflib -kiui -trimesh -xatlas -pymeshlab diff --git a/apps/third_party/CRM/run.py b/apps/third_party/CRM/run.py deleted file mode 100644 index 8e14be6e7c7cb1d314d1e82a23a6d250e79ce3b7..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/run.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch -from libs.base_utils import do_resize_content -from imagedream.ldm.util import ( - instantiate_from_config, - get_obj_from_str, -) -from omegaconf import OmegaConf -from PIL import Image -import numpy as np -from inference import generate3d -from huggingface_hub import hf_hub_download -import json -import argparse -import shutil -from model import CRM -import PIL -import rembg -import os -from pipelines import TwoStagePipeline - -rembg_session = rembg.new_session() - -def expand_to_square(image, bg_color=(0, 0, 0, 0)): - # expand image to 1:1 - width, height = image.size - if width == height: - return image - new_size = (max(width, height), max(width, height)) - new_image = Image.new("RGBA", new_size, bg_color) - paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) - new_image.paste(image, paste_position) - return new_image - -def remove_background( - image: PIL.Image.Image, - rembg_session = None, - force: bool = False, - **rembg_kwargs, -) -> PIL.Image.Image: - do_remove = True - if image.mode == "RGBA" and image.getextrema()[3][0] < 255: - # explain why current do not rm bg - print("alhpa channl not enpty, skip remove background, using alpha channel as mask") - background = Image.new("RGBA", image.size, (0, 0, 0, 0)) - image = Image.alpha_composite(background, image) - do_remove = False - do_remove = do_remove or force - if do_remove: - image = rembg.remove(image, session=rembg_session, **rembg_kwargs) - return image - -def do_resize_content(original_image: Image, scale_rate): - # resize image content wile retain the original image size - if scale_rate != 1: - # Calculate the new size after rescaling - new_size = tuple(int(dim * scale_rate) for dim in original_image.size) - # Resize the image while maintaining the aspect ratio - resized_image = original_image.resize(new_size) - # Create a new image with the original size and black background - padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) - paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) - padded_image.paste(resized_image, paste_position) - return padded_image - else: - return original_image - -def add_background(image, bg_color=(255, 255, 255)): - # given an RGBA image, alpha channel is used as mask to add background color - background = Image.new("RGBA", image.size, bg_color) - return Image.alpha_composite(background, image) - - -def preprocess_image(image, background_choice, foreground_ratio, backgroud_color): - """ - input image is a pil image in RGBA, return RGB image - """ - print(background_choice) - if background_choice == "Alpha as mask": - background = Image.new("RGBA", image.size, (0, 0, 0, 0)) - image = Image.alpha_composite(background, image) - else: - image = remove_background(image, rembg_session, force_remove=True) - image = do_resize_content(image, foreground_ratio) - image = expand_to_square(image) - image = add_background(image, backgroud_color) - return image.convert("RGB") - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument( - "--inputdir", - type=str, - default="examples/kunkun.webp", - help="dir for input image", - ) - parser.add_argument( - "--scale", - type=float, - default=5.0, - ) - parser.add_argument( - "--step", - type=int, - default=50, - ) - parser.add_argument( - "--bg_choice", - type=str, - default="Auto Remove background", - help="[Auto Remove background] or [Alpha as mask]", - ) - parser.add_argument( - "--outdir", - type=str, - default="out/", - ) - args = parser.parse_args() - - - img = Image.open(args.inputdir) - img = preprocess_image(img, args.bg_choice, 1.0, (127, 127, 127)) - os.makedirs(args.outdir, exist_ok=True) - img.save(args.outdir+"preprocessed_image.png") - - crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth") - specs = json.load(open("configs/specs_objaverse_total.json")) - model = CRM(specs).to("cuda") - model.load_state_dict(torch.load(crm_path, map_location = "cuda"), strict=False) - - stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config - stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config - stage2_sampler_config = stage2_config.sampler - stage1_sampler_config = stage1_config.sampler - - stage1_model_config = stage1_config.models - stage2_model_config = stage2_config.models - - xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth") - pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth") - stage1_model_config.resume = pixel_path - stage2_model_config.resume = xyz_path - - pipeline = TwoStagePipeline( - stage1_model_config, - stage2_model_config, - stage1_sampler_config, - stage2_sampler_config, - ) - - rt_dict = pipeline(img, scale=args.scale, step=args.step) - stage1_images = rt_dict["stage1_images"] - stage2_images = rt_dict["stage2_images"] - np_imgs = np.concatenate(stage1_images, 1) - np_xyzs = np.concatenate(stage2_images, 1) - Image.fromarray(np_imgs).save(args.outdir+"pixel_images.png") - Image.fromarray(np_xyzs).save(args.outdir+"xyz_images.png") - - glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, "cuda") - shutil.copy(obj_path, args.outdir+"output3d.zip") \ No newline at end of file diff --git a/apps/third_party/CRM/util/__init__.py b/apps/third_party/CRM/util/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/apps/third_party/CRM/util/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index c2271f9a2aee0aef5a1ec6b84fa433a3e04bfcd6..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/util/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/util/__pycache__/flexicubes.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/flexicubes.cpython-38.pyc deleted file mode 100644 index 37685af5cd939c94104af01552419fc96e7e4688..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/util/__pycache__/flexicubes.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/util/__pycache__/flexicubes_geometry.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/flexicubes_geometry.cpython-38.pyc deleted file mode 100644 index 7b063e820c975d5640b6f1c2684f633f461a9fac..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/util/__pycache__/flexicubes_geometry.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/util/__pycache__/renderer.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/renderer.cpython-38.pyc deleted file mode 100644 index 5fc688d3238a84ca4ffd4c61656f7a63679d72e8..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/util/__pycache__/renderer.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/util/__pycache__/tables.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/tables.cpython-38.pyc deleted file mode 100644 index e9acbd82a2a5d5a7ef03a34c9b89035b2bcd6b39..0000000000000000000000000000000000000000 Binary files a/apps/third_party/CRM/util/__pycache__/tables.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/CRM/util/flexicubes.py b/apps/third_party/CRM/util/flexicubes.py deleted file mode 100644 index 0e12d371362ea7f7f315bd33d866f4bb3510eadb..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/util/flexicubes.py +++ /dev/null @@ -1,579 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. -import torch -from util.tables import * - -__all__ = [ - 'FlexiCubes' -] - - -class FlexiCubes: - """ - This class implements the FlexiCubes method for extracting meshes from scalar fields. - It maintains a series of lookup tables and indices to support the mesh extraction process. - FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances - the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting - the surface representation through gradient-based optimization. - - During instantiation, the class loads DMC tables from a file and transforms them into - PyTorch tensors on the specified device. - - Attributes: - device (str): Specifies the computational device (default is "cuda"). - dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges - associated with each dual vertex in 256 Marching Cubes (MC) configurations. - num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of - the 256 MC configurations. - check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 - of the DMC configurations. - tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. - quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles - along one diagonal. - quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into - two triangles along the other diagonal. - quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles - during training by connecting all edges to their midpoints. - cube_corners (torch.Tensor): Defines the positions of a standard unit cube's - eight corners in 3D space, ordered starting from the origin (0,0,0), - moving along the x-axis, then y-axis, and finally z-axis. - Used as a blueprint for generating a voxel grid. - cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used - to retrieve the case id. - cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. - Used to retrieve edge vertices in DMC. - edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with - their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the - first edge is oriented along the x-axis. - dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges - across four adjacent cubes to the shared faces of these cubes. For instance, - dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along - the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. - This tensor is only utilized during isosurface tetrahedralization. - adj_pairs (torch.Tensor): - A tensor containing index pairs that correspond to neighboring cubes that share the same edge. - qef_reg_scale (float): - The scaling factor applied to the regularization loss to prevent issues with singularity - when solving the QEF. This parameter is only used when a 'grad_func' is specified. - weight_scale (float): - The scale of weights in FlexiCubes. Should be between 0 and 1. - """ - - def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): - - self.device = device - self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) - self.num_vd_table = torch.tensor(num_vd_table, - dtype=torch.long, device=device, requires_grad=False) - self.check_table = torch.tensor( - check_table, - dtype=torch.long, device=device, requires_grad=False) - - self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) - self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) - self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) - self.quad_split_train = torch.tensor( - [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) - - self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ - 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) - self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) - self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, - 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) - - self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], - dtype=torch.long, device=device) - self.dir_faces_table = torch.tensor([ - [[5, 4], [3, 2], [4, 5], [2, 3]], - [[5, 4], [1, 0], [4, 5], [0, 1]], - [[3, 2], [1, 0], [2, 3], [0, 1]] - ], dtype=torch.long, device=device) - self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) - self.qef_reg_scale = qef_reg_scale - self.weight_scale = weight_scale - - def construct_voxel_grid(self, res): - """ - Generates a voxel grid based on the specified resolution. - - Args: - res (int or list[int]): The resolution of the voxel grid. If an integer - is provided, it is used for all three dimensions. If a list or tuple - of 3 integers is provided, they define the resolution for the x, - y, and z dimensions respectively. - - Returns: - (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the - cube corners (index into vertices) of the constructed voxel grid. - The vertices are centered at the origin, with the length of each - dimension in the grid being one. - """ - base_cube_f = torch.arange(8).to(self.device) - if isinstance(res, int): - res = (res, res, res) - voxel_grid_template = torch.ones(res, device=self.device) - - res = torch.tensor([res], dtype=torch.float, device=self.device) - coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 - verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) - cubes = (base_cube_f.unsqueeze(0) + - torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) - - verts_rounded = torch.round(verts * 10**5) / (10**5) - verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) - cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) - - return verts_unique - 0.5, cubes - - def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, - gamma_f=None, training=False, output_tetmesh=False, grad_func=None): - r""" - Main function for mesh extraction from scalar field using FlexiCubes. This function converts - discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, - to triangle or tetrahedral meshes using a differentiable operation as described in - `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances - mesh quality and geometric fidelity by adjusting the surface representation based on gradient - optimization. The output surface is differentiable with respect to the input vertex positions, - scalar field values, and weight parameters. - - If you intend to extract a surface mesh from a fixed Signed Distance Field without the - optimization of parameters, it is suggested to provide the "grad_func" which should - return the surface gradient at any given 3D position. When grad_func is provided, the process - to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as - described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. - Please note, this approach is non-differentiable. - - For more details and example usage in optimization, refer to the - `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. - - Args: - x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. - s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values - denote that the corresponding vertex resides inside the isosurface. This affects - the directions of the extracted triangle faces and volume to be tetrahedralized. - cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. - res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it - is used for all three dimensions. If a list or tuple of 3 integers is provided, they - specify the resolution for the x, y, and z dimensions respectively. - beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual - vertices positioning. Defaults to uniform value for all edges. - alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual - vertices positioning. Defaults to uniform value for all vertices. - gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of - quadrilaterals into triangles. Defaults to uniform value for all cubes. - training (bool, optional): If set to True, applies differentiable quad splitting for - training. Defaults to False. - output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, - outputs a triangular mesh. Defaults to False. - grad_func (callable, optional): A function to compute the surface gradient at specified - 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 - tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. - - Returns: - (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: - - Vertices for the extracted triangular/tetrahedral mesh. - - Faces for the extracted triangular/tetrahedral mesh. - - Regularizer L_dev, computed per dual vertex. - - .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: - https://research.nvidia.com/labs/toronto-ai/flexicubes/ - .. _Manifold Dual Contouring: - https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf - """ - - surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) - if surf_cubes.sum() == 0: - return torch.zeros( - (0, 3), - device=self.device), torch.zeros( - (0, 4), - dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( - (0, 3), - dtype=torch.long, device=self.device), torch.zeros( - (0), - device=self.device) - beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) - - case_ids = self._get_case_id(occ_fx8, surf_cubes, res) - - surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) - - vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( - x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) - vertices, faces, s_edges, edge_indices = self._triangulate( - s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) - if not output_tetmesh: - return vertices, faces, L_dev - else: - vertices, tets = self._tetrahedralize( - x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, - surf_cubes, training) - return vertices, tets, L_dev - - def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): - """ - Regularizer L_dev as in Equation 8 - """ - dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) - mean_l2 = torch.zeros_like(vd[:, 0]) - mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() - mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() - return mad - - def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): - """ - Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. - """ - n_cubes = surf_cubes.shape[0] - - if beta_fx12 is not None: - beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) - else: - beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) - - if alpha_fx8 is not None: - alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) - else: - alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) - - if gamma_f is not None: - gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 - else: - gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) - - return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] - - @torch.no_grad() - def _get_case_id(self, occ_fx8, surf_cubes, res): - """ - Obtains the ID of topology cases based on cell corner occupancy. This function resolves the - ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the - supplementary material. It should be noted that this function assumes a regular grid. - """ - case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) - - problem_config = self.check_table.to(self.device)[case_ids] - to_check = problem_config[..., 0] == 1 - problem_config = problem_config[to_check] - if not isinstance(res, (list, tuple)): - res = [res, res, res] - - # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, - # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). - # This allows efficient checking on adjacent cubes. - problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) - vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 - vol_idx_problem = vol_idx[surf_cubes][to_check] - problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config - vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] - - within_range = ( - vol_idx_problem_adj[..., 0] >= 0) & ( - vol_idx_problem_adj[..., 0] < res[0]) & ( - vol_idx_problem_adj[..., 1] >= 0) & ( - vol_idx_problem_adj[..., 1] < res[1]) & ( - vol_idx_problem_adj[..., 2] >= 0) & ( - vol_idx_problem_adj[..., 2] < res[2]) - - vol_idx_problem = vol_idx_problem[within_range] - vol_idx_problem_adj = vol_idx_problem_adj[within_range] - problem_config = problem_config[within_range] - problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], - vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] - # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. - to_invert = (problem_config_adj[..., 0] == 1) - idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] - case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) - return case_ids - - @torch.no_grad() - def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): - """ - Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge - can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge - and marks the cube edges with this index. - """ - occ_n = s_n < 0 - all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) - - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 - - surf_edges_mask = mask_edges[_idx_map] - counts = counts[_idx_map] - - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 - mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) - # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index - # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. - idx_map = mapping[_idx_map] - surf_edges = unique_edges[mask_edges] - return surf_edges, idx_map, counts, surf_edges_mask - - @torch.no_grad() - def _identify_surf_cubes(self, s_n, cube_fx8): - """ - Identifies grid cubes that intersect with the underlying surface by checking if the signs at - all corners are not identical. - """ - occ_n = s_n < 0 - occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) - _occ_sum = torch.sum(occ_fx8, -1) - surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) - return surf_cubes, occ_fx8 - - def _linear_interp(self, edges_weight, edges_x): - """ - Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. - """ - edge_dim = edges_weight.dim() - 2 - assert edges_weight.shape[edge_dim] == 2 - edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - - torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) - denominator = edges_weight.sum(edge_dim) - ue = (edges_x * edges_weight).sum(edge_dim) / denominator - return ue - - def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): - p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) - norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) - c_bx3 = c_bx3.reshape(-1, 3) - A = norm_bxnx3 - B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) - - A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) - B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) - A = torch.cat([A, A_reg], 1) - B = torch.cat([B, B_reg], 1) - dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) - return dual_verts - - def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): - """ - Computes the location of dual vertices as described in Section 4.2 - """ - alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) - surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) - surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) - zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) - - idx_map = idx_map.reshape(-1, 12) - num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) - edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] - - total_num_vd = 0 - vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) - if grad_func is not None: - normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) - vd = [] - for num in torch.unique(num_vd): - cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) - curr_num_vd = cur_cubes.sum() * num - curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) - curr_edge_group_to_vd = torch.arange( - curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd - total_num_vd += curr_num_vd - curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ - cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) - - curr_mask = (curr_edge_group != -1) - edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) - edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) - edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) - vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) - vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) - - if grad_func is not None: - with torch.no_grad(): - cube_e_verts_idx = idx_map[cur_cubes] - curr_edge_group[~curr_mask] = 0 - - verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) - verts_group_idx[verts_group_idx == -1] = 0 - verts_group_pos = torch.index_select( - input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) - v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1) - curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) - verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) - - normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( - -1, num.item(), 7, - 3) - curr_mask = curr_mask.squeeze(2) - vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, - verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) - edge_group = torch.cat(edge_group) - edge_group_to_vd = torch.cat(edge_group_to_vd) - edge_group_to_cube = torch.cat(edge_group_to_cube) - vd_num_edges = torch.cat(vd_num_edges) - vd_gamma = torch.cat(vd_gamma) - - if grad_func is not None: - vd = torch.cat(vd) - L_dev = torch.zeros([1], device=self.device) - else: - vd = torch.zeros((total_num_vd, 3), device=self.device) - beta_sum = torch.zeros((total_num_vd, 1), device=self.device) - - idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) - - x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) - s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) - - zero_crossing_group = torch.index_select( - input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) - - alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) - ue_group = self._linear_interp(s_group * alpha_group, x_group) - - beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) - beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) - vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum - L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) - - v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd - - vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * - 12 + edge_group, src=v_idx[edge_group_to_vd]) - - return vd, L_dev, vd_gamma, vd_idx_map - - def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): - """ - Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into - triangles based on the gamma parameter, as described in Section 4.3. - """ - with torch.no_grad(): - group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. - group = idx_map.reshape(-1)[group_mask] - vd_idx = vd_idx_map[group_mask] - edge_indices, indices = torch.sort(group, stable=True) - quad_vd_idx = vd_idx[indices].reshape(-1, 4) - - # Ensure all face directions point towards the positive SDF to maintain consistent winding. - s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) - flip_mask = s_edges[:, 0] > 0 - quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], - quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) - if grad_func is not None: - # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. - with torch.no_grad(): - vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) - quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) - gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) - gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) - else: - quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) - gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( - 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) - gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( - 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) - if not training: - mask = (gamma_02 > gamma_13).squeeze(1) - faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) - faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] - faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] - faces = faces.reshape(-1, 3) - else: - vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) - vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + - torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 - vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + - torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 - weight_sum = (gamma_02 + gamma_13) + 1e-8 - vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / - weight_sum.unsqueeze(-1)).squeeze(1) - vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] - vd = torch.cat([vd, vd_center]) - faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) - faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) - return vd, faces, s_edges, edge_indices - - def _tetrahedralize( - self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, - surf_cubes, training): - """ - Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. - """ - occ_n = s_n < 0 - occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) - occ_sum = torch.sum(occ_fx8, -1) - - inside_verts = x_nx3[occ_n] - mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 - mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] - """ - For each grid edge connecting two grid vertices with different - signs, we first form a four-sided pyramid by connecting one - of the grid vertices with four mesh vertices that correspond - to the grid edge and then subdivide the pyramid into two tetrahedra - """ - inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ - s_edges < 0]] - if not training: - inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) - else: - inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) - - tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) - """ - For each grid edge connecting two grid vertices with the - same sign, the tetrahedron is formed by the two grid vertices - and two vertices in consecutive adjacent cells - """ - inside_cubes = (occ_sum == 8) - inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) - inside_cubes_center_idx = torch.arange( - inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] - - surface_n_inside_cubes = surf_cubes | inside_cubes - edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), - dtype=torch.long, device=x_nx3.device) * -1 - surf_cubes = surf_cubes[surface_n_inside_cubes] - inside_cubes = inside_cubes[surface_n_inside_cubes] - edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) - edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx - - all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 - mask = mask_edges[_idx_map] - counts = counts[_idx_map] - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 - mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) - idx_map = mapping[_idx_map] - - group_mask = (counts == 4) & mask - group = idx_map.reshape(-1)[group_mask] - edge_indices, indices = torch.sort(group) - cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, - device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] - edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( - 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] - # Identify the face shared by the adjacent cells. - cube_idx_4 = cube_idx[indices].reshape(-1, 4) - edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] - shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) - cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) - # Identify an edge of the face with different signs and - # select the mesh vertex corresponding to the identified edge. - case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 - case_ids_expand[surf_cubes] = case_ids - cases = case_ids_expand[cube_idx_4x2] - quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) - mask = (quad_edge == -1).sum(-1) == 0 - inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) - tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] - - tets = torch.cat([tets_surface, tets_inside]) - vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) - return vertices, tets diff --git a/apps/third_party/CRM/util/flexicubes_geometry.py b/apps/third_party/CRM/util/flexicubes_geometry.py deleted file mode 100644 index 5dec7635e7f275f6ef3867223ea2700d3af6f4ea..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/util/flexicubes_geometry.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -from util.flexicubes import FlexiCubes # replace later -# from dmtet import sdf_reg_loss_batch -import torch.nn.functional as F - -def get_center_boundary_index(grid_res, device): - v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) - v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True - center_indices = torch.nonzero(v.reshape(-1)) - - v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False - v[:2, ...] = True - v[-2:, ...] = True - v[:, :2, ...] = True - v[:, -2:, ...] = True - v[:, :, :2] = True - v[:, :, -2:] = True - boundary_indices = torch.nonzero(v.reshape(-1)) - return center_indices, boundary_indices - -############################################################################### -# Geometry interface -############################################################################### -class FlexiCubesGeometry(object): - def __init__( - self, grid_res=64, scale=2.0, device='cuda', renderer=None, - render_type='neural_render', args=None): - super(FlexiCubesGeometry, self).__init__() - self.grid_res = grid_res - self.device = device - self.args = args - self.fc = FlexiCubes(device, weight_scale=0.5) - self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) - if isinstance(scale, list): - self.verts[:, 0] = self.verts[:, 0] * scale[0] - self.verts[:, 1] = self.verts[:, 1] * scale[1] - self.verts[:, 2] = self.verts[:, 2] * scale[1] - else: - self.verts = self.verts * scale - - all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) - self.all_edges = torch.unique(all_edges, dim=0) - - # Parameters used for fix boundary sdf - self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) - self.renderer = renderer - self.render_type = render_type - - def getAABB(self): - return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values - - def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): - if indices is None: - indices = self.indices - - verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, - beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], - gamma_f=weight_n[:, 20], training=is_training - ) - return verts, faces, v_reg_loss - - - def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): - return_value = dict() - if self.render_type == 'neural_render': - tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( - mesh_v_nx3.unsqueeze(dim=0), - mesh_f_fx3.int(), - camera_mv_bx4x4, - mesh_v_nx3.unsqueeze(dim=0), - resolution=resolution, - device=self.device, - hierarchical_mask=hierarchical_mask - ) - - return_value['tex_pos'] = tex_pos - return_value['mask'] = mask - return_value['hard_mask'] = hard_mask - return_value['rast'] = rast - return_value['v_pos_clip'] = v_pos_clip - return_value['mask_pyramid'] = mask_pyramid - return_value['depth'] = depth - else: - raise NotImplementedError - - return return_value - - def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): - # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 - v_list = [] - f_list = [] - n_batch = v_deformed_bxnx3.shape[0] - all_render_output = [] - for i_batch in range(n_batch): - verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) - v_list.append(verts_nx3) - f_list.append(faces_fx3) - render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) - all_render_output.append(render_output) - - # Concatenate all render output - return_keys = all_render_output[0].keys() - return_value = dict() - for k in return_keys: - value = [v[k] for v in all_render_output] - return_value[k] = value - # We can do concatenation outside of the render - return return_value diff --git a/apps/third_party/CRM/util/renderer.py b/apps/third_party/CRM/util/renderer.py deleted file mode 100644 index 2d1dc69dcaaf902cbb70af06834015ed1730dc0e..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/util/renderer.py +++ /dev/null @@ -1,49 +0,0 @@ - -import torch -import torch.nn as nn -import nvdiffrast.torch as dr -from util.flexicubes_geometry import FlexiCubesGeometry - -class Renderer(nn.Module): - def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type): - super().__init__() - - self.tet_grid_size = tet_grid_size - self.camera_angle_num = camera_angle_num - self.scale = scale - self.geo_type = geo_type - self.glctx = dr.RasterizeCudaContext() - - if self.geo_type == "flex": - self.flexicubes = FlexiCubesGeometry(grid_res = self.tet_grid_size) - - def forward(self, data, sdf, deform, verts, tets, training=False, weight = None): - - results = {} - - deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95 - if self.geo_type == "flex": - deform = deform *0.5 - - v_deformed = verts + deform - - verts_list = [] - faces_list = [] - reg_list = [] - n_shape = verts.shape[0] - for i in range(n_shape): - verts_i, faces_i, reg_i = self.flexicubes.get_mesh(v_deformed[i], sdf[i].squeeze(dim=-1), - with_uv=False, indices=tets, weight_n=weight[i], is_training=training) - - verts_list.append(verts_i) - faces_list.append(faces_i) - reg_list.append(reg_i) - verts = verts_list - faces = faces_list - - flexicubes_surface_reg = torch.cat(reg_list).mean() - flexicubes_weight_reg = (weight ** 2).mean() - results["flex_surf_loss"] = flexicubes_surface_reg - results["flex_weight_loss"] = flexicubes_weight_reg - - return results, verts, faces \ No newline at end of file diff --git a/apps/third_party/CRM/util/tables.py b/apps/third_party/CRM/util/tables.py deleted file mode 100644 index 936a4bc5e2f95891f72651f2c42272e01a3a2bc3..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/util/tables.py +++ /dev/null @@ -1,791 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. -dmc_table = [ -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] -] -num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, -2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, -1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, -1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, -2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, -3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, -2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, -1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, -1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] -check_table = [ -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 194], -[1, -1, 0, 0, 193], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 164], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 161], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 152], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 145], -[1, 0, 0, 1, 144], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 137], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 133], -[1, 0, 1, 0, 132], -[1, 1, 0, 0, 131], -[1, 1, 0, 0, 130], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 100], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 98], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 96], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 88], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 82], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 74], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 72], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 70], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 67], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 65], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 56], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 52], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 44], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 40], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 38], -[1, 0, -1, 0, 37], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 33], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 28], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 26], -[1, 0, 0, -1, 25], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 20], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 18], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 9], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 6], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0] -] -tet_table = [ -[-1, -1, -1, -1, -1, -1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, -1], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, -1], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, -1, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, -1, 2, 4, 4, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, 5, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, -1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[-1, 1, 1, 4, 4, 1], -[0, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[8, 8, 8, 8, 8, 8], -[1, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 4, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 5, 5, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[6, 6, 6, 6, 6, 6], -[6, -1, 0, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 4, -1, 6, 4, 6], -[6, 4, 0, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 2, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 1, 1, 6, -1, 6], -[6, 1, 1, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 4], -[2, 2, 2, 2, 2, 2], -[6, 1, 1, 6, 4, 6], -[6, 1, 1, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 5, 0, 5, 0, 5], -[5, 5, 5, 5, 5, 5], -[5, 5, 5, 5, 5, 5], -[0, 5, 0, 5, 0, 5], -[-1, 5, 0, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[4, 5, -1, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[4, 5, 0, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 6, 6, 6, 6, 6], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, -1, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[2, 5, 2, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 4], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 6, 2, 6, 6, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 1, 4, 1], -[0, 1, 1, 1, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 0, 0, 6, 0, 6], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[5, 5, 5, 5, 5, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 4, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[4, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[8, 8, 8, 8, 8, 8], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 1, 1, 4, 4, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 4, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[12, 12, 12, 12, 12, 12] -] diff --git a/apps/third_party/CRM/util/utils.py b/apps/third_party/CRM/util/utils.py deleted file mode 100644 index ab3cad9b703b1cdd9a3d6a66e82f898185ebf6c8..0000000000000000000000000000000000000000 --- a/apps/third_party/CRM/util/utils.py +++ /dev/null @@ -1,194 +0,0 @@ -import numpy as np -import torch -import random - - -# Reworked so this matches gluPerspective / glm::perspective, using fovy -def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): - # y = np.tan(fovy / 2) - x = np.tan(fovx / 2) - return torch.tensor([[1/x, 0, 0, 0], - [ 0, -aspect/x, 0, 0], - [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], - [ 0, 0, -1, 0]], dtype=torch.float32, device=device) - - -def translate(x, y, z, device=None): - return torch.tensor([[1, 0, 0, x], - [0, 1, 0, y], - [0, 0, 1, z], - [0, 0, 0, 1]], dtype=torch.float32, device=device) - - -def rotate_x(a, device=None): - s, c = np.sin(a), np.cos(a) - return torch.tensor([[1, 0, 0, 0], - [0, c, -s, 0], - [0, s, c, 0], - [0, 0, 0, 1]], dtype=torch.float32, device=device) - - -def rotate_y(a, device=None): - s, c = np.sin(a), np.cos(a) - return torch.tensor([[ c, 0, s, 0], - [ 0, 1, 0, 0], - [-s, 0, c, 0], - [ 0, 0, 0, 1]], dtype=torch.float32, device=device) - - -def rotate_z(a, device=None): - s, c = np.sin(a), np.cos(a) - return torch.tensor([[c, -s, 0, 0], - [s, c, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]], dtype=torch.float32, device=device) - -@torch.no_grad() -def batch_random_rotation_translation(b, t, device=None): - m = np.random.normal(size=[b, 3, 3]) - m[:, 1] = np.cross(m[:, 0], m[:, 2]) - m[:, 2] = np.cross(m[:, 0], m[:, 1]) - m = m / np.linalg.norm(m, axis=2, keepdims=True) - m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant') - m[:, 3, 3] = 1.0 - m[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3]) - return torch.tensor(m, dtype=torch.float32, device=device) - -@torch.no_grad() -def random_rotation_translation(t, device=None): - m = np.random.normal(size=[3, 3]) - m[1] = np.cross(m[0], m[2]) - m[2] = np.cross(m[0], m[1]) - m = m / np.linalg.norm(m, axis=1, keepdims=True) - m = np.pad(m, [[0, 1], [0, 1]], mode='constant') - m[3, 3] = 1.0 - m[:3, 3] = np.random.uniform(-t, t, size=[3]) - return torch.tensor(m, dtype=torch.float32, device=device) - - -@torch.no_grad() -def random_rotation(device=None): - m = np.random.normal(size=[3, 3]) - m[1] = np.cross(m[0], m[2]) - m[2] = np.cross(m[0], m[1]) - m = m / np.linalg.norm(m, axis=1, keepdims=True) - m = np.pad(m, [[0, 1], [0, 1]], mode='constant') - m[3, 3] = 1.0 - m[:3, 3] = np.array([0,0,0]).astype(np.float32) - return torch.tensor(m, dtype=torch.float32, device=device) - - -def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return torch.sum(x*y, -1, keepdim=True) - - -def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: - return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN - - -def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: - return x / length(x, eps) - - -def lr_schedule(iter, warmup_iter, scheduler_decay): - if iter < warmup_iter: - return iter / warmup_iter - return max(0.0, 10 ** ( - -(iter - warmup_iter) * scheduler_decay)) - - -def trans_depth(depth): - depth = depth[0].detach().cpu().numpy() - valid = depth > 0 - depth[valid] -= depth[valid].min() - depth[valid] = ((depth[valid] / depth[valid].max()) * 255) - return depth.astype('uint8') - - -def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): - assert isinstance(input, torch.Tensor) - if posinf is None: - posinf = torch.finfo(input.dtype).max - if neginf is None: - neginf = torch.finfo(input.dtype).min - assert nan == 0 - return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) - - -def load_item(filepath): - with open(filepath, 'r') as f: - items = [name.strip() for name in f.readlines()] - return set(items) - -def load_prompt(filepath): - uuid2prompt = {} - with open(filepath, 'r') as f: - for line in f.readlines(): - list_line = line.split(',') - uuid2prompt[list_line[0]] = ','.join(list_line[1:]).strip() - return uuid2prompt - -def resize_and_center_image(image_tensor, scale=0.95, c = 0, shift = 0, rgb=False, aug_shift = 0): - if scale == 1: - return image_tensor - B, C, H, W = image_tensor.shape - new_H, new_W = int(H * scale), int(W * scale) - resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False).squeeze(0) - background = torch.zeros_like(image_tensor) + c - start_y, start_x = (H - new_H) // 2, (W - new_W) // 2 - if shift == 0: - background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image - else: - for i in range(B): - randx = random.randint(-shift, shift) - randy = random.randint(-shift, shift) - if rgb == True: - if i == 0 or i==2 or i==4: - randx = 0 - randy = 0 - background[i, :, start_y+randy:start_y + new_H+randy, start_x+randx:start_x + new_W+randx] = resized_image[i] - if aug_shift == 0: - return background - for i in range(B): - for j in range(C): - background[i, j, :, :] += (random.random() - 0.5)*2 * aug_shift / 255 - return background - -def get_tri(triview_color, dim = 1, blender=True, c = 0, scale=0.95, shift = 0, fix = False, rgb=False, aug_shift = 0): - # triview_color: [6,C,H,W] - # rgb is useful when shift is not 0 - triview_color = resize_and_center_image(triview_color, scale=scale, c = c, shift=shift,rgb=rgb, aug_shift = aug_shift) - if blender is False: - triview_color0 = torch.rot90(triview_color[0],k=2,dims=[1,2]) - triview_color1 = torch.rot90(triview_color[4],k=1,dims=[1,2]).flip(2).flip(1) - triview_color2 = torch.rot90(triview_color[5],k=1,dims=[1,2]).flip(2) - triview_color3 = torch.rot90(triview_color[3],k=2,dims=[1,2]).flip(2) - triview_color4 = torch.rot90(triview_color[1],k=3,dims=[1,2]).flip(1) - triview_color5 = torch.rot90(triview_color[2],k=3,dims=[1,2]).flip(1).flip(2) - else: - triview_color0 = torch.rot90(triview_color[2],k=2,dims=[1,2]) - triview_color1 = torch.rot90(triview_color[4],k=0,dims=[1,2]).flip(2).flip(1) - triview_color2 = torch.rot90(torch.rot90(triview_color[0],k=3,dims=[1,2]).flip(2), k=2,dims=[1,2]) - triview_color3 = torch.rot90(torch.rot90(triview_color[5],k=2,dims=[1,2]).flip(2), k=2,dims=[1,2]) - triview_color4 = torch.rot90(triview_color[1],k=2,dims=[1,2]).flip(1).flip(1).flip(2) - triview_color5 = torch.rot90(triview_color[3],k=1,dims=[1,2]).flip(1).flip(2) - if fix == True: - triview_color0[1] = triview_color0[1] * 0 - triview_color0[2] = triview_color0[2] * 0 - triview_color3[1] = triview_color3[1] * 0 - triview_color3[2] = triview_color3[2] * 0 - - triview_color1[0] = triview_color1[0] * 0 - triview_color1[1] = triview_color1[1] * 0 - triview_color4[0] = triview_color4[0] * 0 - triview_color4[1] = triview_color4[1] * 0 - - triview_color2[0] = triview_color2[0] * 0 - triview_color2[2] = triview_color2[2] * 0 - triview_color5[0] = triview_color5[0] * 0 - triview_color5[2] = triview_color5[2] * 0 - color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2) - color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2) - color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim = dim) - return color_tensor_gt - diff --git a/apps/third_party/Era3D/configs/test_unclip-512-6view.yaml b/apps/third_party/Era3D/configs/test_unclip-512-6view.yaml deleted file mode 100644 index d9353a5984d60e96390da5a3d06a70f09ae9d6b0..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/configs/test_unclip-512-6view.yaml +++ /dev/null @@ -1,56 +0,0 @@ -pretrained_model_name_or_path: './MacLab-Era3D-512-6view' -revision: null - -num_views: 6 -validation_dataset: - prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_6view - root_dir: 'examples' - num_views: ${num_views} - bg_color: 'white' - img_wh: [512, 512] - num_validation_samples: 1000 - crop_size: 420 - -pred_type: 'joint' -save_dir: 'mv_res' -save_mode: 'rgba' # 'concat', 'rgba', 'rgb' -seed: 42 -validation_batch_size: 1 -dataloader_num_workers: 1 -local_rank: -1 - -pipe_kwargs: - num_views: ${num_views} - -validation_guidance_scales: [3.0] -pipe_validation_kwargs: - num_inference_steps: 40 - eta: 1.0 - -validation_grid_nrow: ${num_views} -regress_elevation: true -regress_focal_length: true -unet_from_pretrained_kwargs: - unclip: true - sdxl: false - num_views: ${num_views} - sample_size: 64 - zero_init_conv_in: false # modify - - regress_elevation: ${regress_elevation} - regress_focal_length: ${regress_focal_length} - camera_embedding_type: e_de_da_sincos - projection_camera_embeddings_input_dim: 4 # 2 for elevation and 6 for focal_length - zero_init_camera_projection: false - num_regress_blocks: 3 - - cd_attention_last: false - cd_attention_mid: false - multiview_attention: true - sparse_mv_attention: true - selfattn_block: self_rowwise - mvcd_attention: true - - use_dino: false - -enable_xformers_memory_efficient_attention: true \ No newline at end of file diff --git a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/clr_embeds.pt b/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/clr_embeds.pt deleted file mode 100644 index de105d6a0a017da97af4644608a87785ec54d9cb..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/clr_embeds.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b9e51666588d0f075e031262744d371e12076160231aab19a531dbf7ab976e4d -size 946932 diff --git a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/normal_embeds.pt b/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/normal_embeds.pt deleted file mode 100644 index 7fb88dcf24443b235588cf426eba3951316e825f..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/normal_embeds.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:53dfcd17f62fbfd8aeba60b1b05fa7559d72179738fd048e2ac1d53e5be5ed9d -size 946941 diff --git a/apps/third_party/Era3D/data/normal_utils.py b/apps/third_party/Era3D/data/normal_utils.py deleted file mode 100644 index 1a390082ca0d891f044eca3c6fc291d8f29036de..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/data/normal_utils.py +++ /dev/null @@ -1,78 +0,0 @@ -import numpy as np -def deg2rad(deg): - return deg*np.pi/180 - -def inv_RT(RT): - # RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0) - RT_inv = np.linalg.inv(RT) - - return RT_inv[:3, :] -def camNormal2worldNormal(rot_c2w, camNormal): - H,W,_ = camNormal.shape - normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) - - return normal_img - -def worldNormal2camNormal(rot_w2c, normal_map_world): - H,W,_ = normal_map_world.shape - # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) - - # faster version - # Reshape the normal map into a 2D array where each row represents a normal vector - normal_map_flat = normal_map_world.reshape(-1, 3) - - # Transform the normal vectors using the transformation matrix - normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T) - - # Reshape the transformed normal map back to its original shape - normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape) - - return normal_map_camera - -def trans_normal(normal, RT_w2c, RT_w2c_target): - - # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) - # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) - - relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3])) - return worldNormal2camNormal(relative_RT[:3,:3], normal) - -def trans_normal_complex(normal, RT_w2c, RT_w2c_rela_to_cond): - # camview -> world -> condview - normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) - # debug_normal_world = normal2img(normal_world) - - # relative_RT = np.matmul(RT_w2c_rela_to_cond[:3,:3], np.linalg.inv(RT_w2c[:3,:3])) - normal_target_cam = worldNormal2camNormal(RT_w2c_rela_to_cond[:3,:3], normal_world) - # normal_condview = normal2img(normal_target_cam) - return normal_target_cam -def img2normal(img): - return (img/255.)*2-1 - -def normal2img(normal): - return np.uint8((normal*0.5+0.5)*255) - -def norm_normalize(normal, dim=-1): - - normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) - - return normal - -def plot_grid_images(images, row, col, path=None): - import cv2 - """ - Args: - images: np.array [B, H, W, 3] - row: - col: - save_path: - - Returns: - - """ - images = images.detach().cpu().numpy() - assert row * col == images.shape[0] - images = np.vstack([np.hstack(images[r * col:(r + 1) * col]) for r in range(row)]) - if path: - cv2.imwrite(path, images[:,:,::-1] * 255) - return images \ No newline at end of file diff --git a/apps/third_party/Era3D/data/single_image_dataset.py b/apps/third_party/Era3D/data/single_image_dataset.py deleted file mode 100644 index e225ab8a7765173f576942f741bf2008567136d6..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/data/single_image_dataset.py +++ /dev/null @@ -1,247 +0,0 @@ -from typing import Dict -import numpy as np -from omegaconf import DictConfig, ListConfig -import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms -from einops import rearrange -from typing import Literal, Tuple, Optional, Any -import cv2 -import random - -import json -import os, sys -import math - -from glob import glob - -import PIL.Image -from .normal_utils import trans_normal, normal2img, img2normal -import pdb - -import cv2 -import numpy as np - -def add_margin(pil_img, color=0, size=256): - width, height = pil_img.size - result = Image.new(pil_img.mode, (size, size), color) - result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) - return result - -def scale_and_place_object(image, scale_factor): - assert np.shape(image)[-1]==4 # RGBA - - # Extract the alpha channel (transparency) and the object (RGB channels) - alpha_channel = image[:, :, 3] - - # Find the bounding box coordinates of the object - coords = cv2.findNonZero(alpha_channel) - x, y, width, height = cv2.boundingRect(coords) - - # Calculate the scale factor for resizing - original_height, original_width = image.shape[:2] - - if width > height: - size = width - original_size = original_width - else: - size = height - original_size = original_height - - scale_factor = min(scale_factor, size / (original_size+0.0)) - - new_size = scale_factor * original_size - scale_factor = new_size / size - - # Calculate the new size based on the scale factor - new_width = int(width * scale_factor) - new_height = int(height * scale_factor) - - center_x = original_width // 2 - center_y = original_height // 2 - - paste_x = center_x - (new_width // 2) - paste_y = center_y - (new_height // 2) - - # Resize the object (RGB channels) to the new size - rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height)) - - # Create a new RGBA image with the resized image - new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8) - - new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object - - return new_image - -class SingleImageDataset(Dataset): - def __init__(self, - root_dir: str, - num_views: int, - img_wh: Tuple[int, int], - bg_color: str, - crop_size: int = 224, - single_image: Optional[PIL.Image.Image] = None, - num_validation_samples: Optional[int] = None, - filepaths: Optional[list] = None, - cond_type: Optional[str] = None, - prompt_embeds_path: Optional[str] = None, - gt_path: Optional[str] = None - ) -> None: - """Create a dataset from a folder of images. - If you pass in a root directory it will be searched for images - ending in ext (ext can be a list) - """ - self.root_dir = root_dir - self.num_views = num_views - self.img_wh = img_wh - self.crop_size = crop_size - self.bg_color = bg_color - self.cond_type = cond_type - self.gt_path = gt_path - - - if single_image is None: - if filepaths is None: - # Get a list of all files in the directory - file_list = os.listdir(self.root_dir) - else: - file_list = filepaths - - # Filter the files that end with .png or .jpg - self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg', '.webp'))] - else: - self.file_list = None - - # load all images - self.all_images = [] - self.all_alphas = [] - bg_color = self.get_bg_color() - - if single_image is not None: - image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image) - self.all_images.append(image) - self.all_alphas.append(alpha) - else: - for file in self.file_list: - print(os.path.join(self.root_dir, file)) - image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt') - self.all_images.append(image) - self.all_alphas.append(alpha) - - - - self.all_images = self.all_images[:num_validation_samples] - self.all_alphas = self.all_alphas[:num_validation_samples] - - try: - self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt') - self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') # 4view - except: - self.color_text_embeds = torch.load(f'{prompt_embeds_path}/embeds.pt') - self.normal_text_embeds = None - - def __len__(self): - return len(self.all_images) - - def get_bg_color(self): - if self.bg_color == 'white': - bg_color = np.array([1., 1., 1.], dtype=np.float32) - elif self.bg_color == 'black': - bg_color = np.array([0., 0., 0.], dtype=np.float32) - elif self.bg_color == 'gray': - bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) - elif self.bg_color == 'random': - bg_color = np.random.rand(3) - elif isinstance(self.bg_color, float): - bg_color = np.array([self.bg_color] * 3, dtype=np.float32) - else: - raise NotImplementedError - return bg_color - - - def load_image(self, img_path, bg_color, return_type='np', Imagefile=None): - # pil always returns uint8 - if Imagefile is None: - image_input = Image.open(img_path) - else: - image_input = Imagefile - image_size = self.img_wh[0] - - if self.crop_size!=-1: - alpha_np = np.asarray(image_input)[:, :, 3] - coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] - min_x, min_y = np.min(coords, 0) - max_x, max_y = np.max(coords, 0) - ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) - h, w = ref_img_.height, ref_img_.width - scale = self.crop_size / max(h, w) - h_, w_ = int(scale * h), int(scale * w) - ref_img_ = ref_img_.resize((w_, h_)) - image_input = add_margin(ref_img_, size=image_size) - else: - image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) - image_input = image_input.resize((image_size, image_size)) - - # img = scale_and_place_object(img, self.scale_ratio) - img = np.array(image_input) - img = img.astype(np.float32) / 255. # [0, 1] - assert img.shape[-1] == 4 # RGBA - - alpha = img[...,3:4] - img = img[...,:3] * alpha + bg_color * (1 - alpha) - - if return_type == "np": - pass - elif return_type == "pt": - img = torch.from_numpy(img) - alpha = torch.from_numpy(alpha) - else: - raise NotImplementedError - - return img, alpha - - - def __getitem__(self, index): - image = self.all_images[index%len(self.all_images)] - alpha = self.all_alphas[index%len(self.all_images)] - if self.file_list is not None: - filename = self.file_list[index%len(self.all_images)].replace(".png", "") - else: - filename = 'null' - img_tensors_in = [ - image.permute(2, 0, 1) - ] * self.num_views - - alpha_tensors_in = [ - alpha.permute(2, 0, 1) - ] * self.num_views - - img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W) - alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W) - - if self.gt_path is not None: - gt_image = self.gt_images[index%len(self.all_images)] - gt_alpha = self.gt_alpha[index%len(self.all_images)] - gt_img_tensors_in = [gt_image.permute(2, 0, 1) ] * self.num_views - gt_alpha_tensors_in = [gt_alpha.permute(2, 0, 1) ] * self.num_views - gt_img_tensors_in = torch.stack(gt_img_tensors_in, dim=0).float() - gt_alpha_tensors_in = torch.stack(gt_alpha_tensors_in, dim=0).float() - - normal_prompt_embeddings = self.normal_text_embeds if hasattr(self, 'normal_text_embeds') else None - color_prompt_embeddings = self.color_text_embeds if hasattr(self, 'color_text_embeds') else None - - out = { - 'imgs_in': img_tensors_in.unsqueeze(0), - 'alphas': alpha_tensors_in.unsqueeze(0), - 'normal_prompt_embeddings': normal_prompt_embeddings.unsqueeze(0), - 'color_prompt_embeddings': color_prompt_embeddings.unsqueeze(0), - 'filename': filename, - } - - return out - - - diff --git a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_image.py b/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_image.py deleted file mode 100644 index 68d8b4d0ba8d05bfd6a760acf99080cfde7cbb62..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_image.py +++ /dev/null @@ -1,1029 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.embeddings import ImagePositionalEmbeddings -from diffusers.utils import BaseOutput, deprecate -from diffusers.utils.torch_utils import maybe_allow_in_graph -from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention -from diffusers.models.embeddings import PatchEmbed -from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear -from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils.import_utils import is_xformers_available - -from einops import rearrange, repeat -import pdb -import random - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - -def my_repeat(tensor, num_repeats): - """ - Repeat a tensor along a given dimension - """ - if len(tensor.shape) == 3: - return repeat(tensor, "b d c -> (b v) d c", v=num_repeats) - elif len(tensor.shape) == 4: - return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats) - - -@dataclass -class TransformerMV2DModelOutput(BaseOutput): - """ - The output of [`Transformer2DModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - """ - - sample: torch.FloatTensor - - -class TransformerMV2DModel(ModelMixin, ConfigMixin): - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - num_views: int = 1, - cd_attention_last: bool=False, - cd_attention_mid: bool=False, - multiview_attention: bool=True, - sparse_mv_attention: bool = False, - mvcd_attention: bool=False - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) - else: - self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width - ) - elif self.is_input_patches: - assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size - - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicMVTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - mvcd_attention=mvcd_attention - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continuous projections - if use_linear_projection: - self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) - else: - self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input `hidden_states`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 2: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) - - if not return_dict: - return (output,) - - return TransformerMV2DModelOutput(sample=output) - - -@maybe_allow_in_graph -class BasicMVTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - num_views: int = 1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - sparse_mv_attention: bool = False, - mvcd_attention: bool = False - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - self.multiview_attention = multiview_attention - self.sparse_mv_attention = sparse_mv_attention - self.mvcd_attention = mvcd_attention - - self.attn1 = CustomAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=MVAttnProcessor() - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - self.num_views = num_views - - self.cd_attention_last = cd_attention_last - - if self.cd_attention_last: - # Joint task -Attn - self.attn_joint_last = CustomJointAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=JointAttnProcessor() - ) - nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data) - self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - - self.cd_attention_mid = cd_attention_mid - - if self.cd_attention_mid: - print("cross-domain attn in the middle") - # Joint task -Attn - self.attn_joint_mid = CustomJointAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=JointAttnProcessor() - ) - nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data) - self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - ): - assert attention_mask is None # not supported yet - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - num_views=self.num_views, - multiview_attention=self.multiview_attention, - sparse_mv_attention=self.sparse_mv_attention, - mvcd_attention=self.mvcd_attention, - **cross_attention_kwargs, - ) - - - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # joint attention twice - if self.cd_attention_mid: - norm_hidden_states = ( - self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states) - ) - hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - if self.cd_attention_last: - norm_hidden_states = ( - self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states) - ) - hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states - - return hidden_states - - -class CustomAttention(Attention): - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, *args, **kwargs - ): - processor = XFormersMVAttnProcessor() - self.set_processor(processor) - # print("using xformers attention processor") - - -class CustomJointAttention(Attention): - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, *args, **kwargs - ): - processor = XFormersJointAttnProcessor() - self.set_processor(processor) - # print("using xformers attention processor") - -class MVAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_views=1, - multiview_attention=True - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - # print('query', query.shape, 'key', key.shape, 'value', value.shape) - #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) - # pdb.set_trace() - # multi-view self-attention - if multiview_attention: - if num_views <= 6: - # after use xformer; possible to train with 6 views - key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) - value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) - else:# apply sparse attention - pass - # print("use sparse attention") - # # seems that the sparse random sampling cause problems - # # don't use random sampling, just fix the indexes - # onekey = rearrange(key, "(b t) d c -> b t d c", t=num_views) - # onevalue = rearrange(value, "(b t) d c -> b t d c", t=num_views) - # allkeys = [] - # allvalues = [] - # all_indexes = { - # 0 : [0, 2, 3, 4], - # 1: [0, 1, 3, 5], - # 2: [0, 2, 3, 4], - # 3: [0, 2, 3, 4], - # 4: [0, 2, 3, 4], - # 5: [0, 1, 3, 5] - # } - # for jj in range(num_views): - # # valid_index = [x for x in range(0, num_views) if x!= jj] - # # indexes = random.sample(valid_index, 3) + [jj] + [0] - # indexes = all_indexes[jj] - - # indexes = torch.tensor(indexes).long().to(key.device) - # allkeys.append(onekey[:, indexes]) - # allvalues.append(onevalue[:, indexes]) - # keys = torch.stack(allkeys, dim=1) # checked, should be dim=1 - # values = torch.stack(allvalues, dim=1) - # key = rearrange(keys, 'b t f d c -> (b t) (f d) c') - # value = rearrange(values, 'b t f d c -> (b t) (f d) c') - - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class XFormersMVAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_views=1., - multiview_attention=True, - sparse_mv_attention=False, - mvcd_attention=False, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - # from yuancheng; here attention_mask is None - if attention_mask is not None: - # expand our mask's singleton query_tokens dimension: - # [batch*heads, 1, key_tokens] -> - # [batch*heads, query_tokens, key_tokens] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_tokens, key_tokens] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - _, query_tokens, _ = hidden_states.shape - attention_mask = attention_mask.expand(-1, query_tokens, -1) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key_raw = attn.to_k(encoder_hidden_states) - value_raw = attn.to_v(encoder_hidden_states) - - # print('query', query.shape, 'key', key.shape, 'value', value.shape) - #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) - # pdb.set_trace() - # multi-view self-attention - if multiview_attention: - if not sparse_mv_attention: - key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) - value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) - else: - key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c] - value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) - key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c - value = torch.cat([value_front, value_raw], dim=1) - - if mvcd_attention: - # memory efficient, cross domain attention - key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c - value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2) - key_cross = torch.concat([key_1, key_0], dim=0) - value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c - key = torch.cat([key, key_cross], dim=1) - value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c - else: - # print("don't use multiview attention.") - key = key_raw - value = value_raw - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - - -class XFormersJointAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_tasks=2 - ): - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - # from yuancheng; here attention_mask is None - if attention_mask is not None: - # expand our mask's singleton query_tokens dimension: - # [batch*heads, 1, key_tokens] -> - # [batch*heads, query_tokens, key_tokens] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_tokens, key_tokens] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - _, query_tokens, _ = hidden_states.shape - attention_mask = attention_mask.expand(-1, query_tokens, -1) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - assert num_tasks == 2 # only support two tasks now - - key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c - value_0, value_1 = torch.chunk(value, dim=0, chunks=2) - key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c - value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c - key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c - value = torch.cat([value]*2, dim=0) # (2 b t) 2d c - - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class JointAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_tasks=2 - ): - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - assert num_tasks == 2 # only support two tasks now - - key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c - value_0, value_1 = torch.chunk(value, dim=0, chunks=2) - key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c - value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c - key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c - value = torch.cat([value]*2, dim=0) # (2 b t) 2d c - - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - diff --git a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_rowwise.py b/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_rowwise.py deleted file mode 100644 index 5bf9643718a58410e099f71af2888f8da274fc5d..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_rowwise.py +++ /dev/null @@ -1,978 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.embeddings import ImagePositionalEmbeddings -from diffusers.utils import BaseOutput, deprecate -from diffusers.utils.torch_utils import maybe_allow_in_graph -from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention -from diffusers.models.embeddings import PatchEmbed -from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear -from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils.import_utils import is_xformers_available - -from einops import rearrange -import pdb -import random -import math - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -@dataclass -class TransformerMV2DModelOutput(BaseOutput): - """ - The output of [`Transformer2DModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - """ - - sample: torch.FloatTensor - - -class TransformerMV2DModel(ModelMixin, ConfigMixin): - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - num_views: int = 1, - cd_attention_last: bool=False, - cd_attention_mid: bool=False, - multiview_attention: bool=True, - sparse_mv_attention: bool = True, # not used - mvcd_attention: bool=False - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) - else: - self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width - ) - elif self.is_input_patches: - assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size - - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicMVTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - mvcd_attention=mvcd_attention - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continuous projections - if use_linear_projection: - self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) - else: - self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input `hidden_states`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 2: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) - - if not return_dict: - return (output,) - - return TransformerMV2DModelOutput(sample=output) - - -@maybe_allow_in_graph -class BasicMVTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - num_views: int = 1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - mvcd_attention: bool = False, - rowwise_attention: bool = True - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - self.multiview_attention = multiview_attention - self.mvcd_attention = mvcd_attention - self.rowwise_attention = multiview_attention and rowwise_attention - - # rowwise multiview attention - - print('INFO: using row wise attention...') - - self.attn1 = CustomAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=MVAttnProcessor() - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - self.num_views = num_views - - self.cd_attention_last = cd_attention_last - - if self.cd_attention_last: - # Joint task -Attn - self.attn_joint = CustomJointAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=JointAttnProcessor() - ) - nn.init.zeros_(self.attn_joint.to_out[0].weight.data) - self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - - self.cd_attention_mid = cd_attention_mid - - if self.cd_attention_mid: - print("joint twice") - # Joint task -Attn - self.attn_joint_twice = CustomJointAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=JointAttnProcessor() - ) - nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data) - self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - ): - assert attention_mask is None # not supported yet - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - multiview_attention=self.multiview_attention, - mvcd_attention=self.mvcd_attention, - num_views=self.num_views, - **cross_attention_kwargs, - ) - - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # joint attention twice - if self.cd_attention_mid: - norm_hidden_states = ( - self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states) - ) - hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - if self.cd_attention_last: - norm_hidden_states = ( - self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states) - ) - hidden_states = self.attn_joint(norm_hidden_states) + hidden_states - - return hidden_states - - -class CustomAttention(Attention): - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, *args, **kwargs - ): - processor = XFormersMVAttnProcessor() - self.set_processor(processor) - # print("using xformers attention processor") - - -class CustomJointAttention(Attention): - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, *args, **kwargs - ): - processor = XFormersJointAttnProcessor() - self.set_processor(processor) - # print("using xformers attention processor") - -class MVAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_views=1, - multiview_attention=True - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - height = int(math.sqrt(sequence_length)) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - # print('query', query.shape, 'key', key.shape, 'value', value.shape) - #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) - # pdb.set_trace() - # multi-view self-attention - key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class XFormersMVAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_views=1, - multiview_attention=True, - mvcd_attention=False, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - height = int(math.sqrt(sequence_length)) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # from yuancheng; here attention_mask is None - if attention_mask is not None: - # expand our mask's singleton query_tokens dimension: - # [batch*heads, 1, key_tokens] -> - # [batch*heads, query_tokens, key_tokens] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_tokens, key_tokens] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - _, query_tokens, _ = hidden_states.shape - attention_mask = attention_mask.expand(-1, query_tokens, -1) - - if attn.group_norm is not None: - print('Warning: using group norm, pay attention to use it in row-wise attention') - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key_raw = attn.to_k(encoder_hidden_states) - value_raw = attn.to_v(encoder_hidden_states) - - # print('query', query.shape, 'key', key.shape, 'value', value.shape) - # pdb.set_trace() - - key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) - if mvcd_attention: - # memory efficient, cross domain attention - key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c - value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2) - key_cross = torch.concat([key_1, key_0], dim=0) - value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c - key = torch.cat([key, key_cross], dim=1) - value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c - - - query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64]) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - # print(hidden_states.shape) - hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class XFormersJointAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_tasks=2 - ): - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - # from yuancheng; here attention_mask is None - if attention_mask is not None: - # expand our mask's singleton query_tokens dimension: - # [batch*heads, 1, key_tokens] -> - # [batch*heads, query_tokens, key_tokens] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_tokens, key_tokens] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - _, query_tokens, _ = hidden_states.shape - attention_mask = attention_mask.expand(-1, query_tokens, -1) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - assert num_tasks == 2 # only support two tasks now - - key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c - value_0, value_1 = torch.chunk(value, dim=0, chunks=2) - key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c - value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c - key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c - value = torch.cat([value]*2, dim=0) # (2 b t) 2d c - - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class JointAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_tasks=2 - ): - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - assert num_tasks == 2 # only support two tasks now - - key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c - value_0, value_1 = torch.chunk(value, dim=0, chunks=2) - key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c - value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c - key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c - value = torch.cat([value]*2, dim=0) # (2 b t) 2d c - - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states \ No newline at end of file diff --git a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_self_rowwise.py b/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_self_rowwise.py deleted file mode 100644 index 4eca703670c8d2f1d0fbb08027694e87acb8c926..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_self_rowwise.py +++ /dev/null @@ -1,1038 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.embeddings import ImagePositionalEmbeddings -from diffusers.utils import BaseOutput, deprecate -from diffusers.utils.torch_utils import maybe_allow_in_graph -from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention -from diffusers.models.embeddings import PatchEmbed -from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear -from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils.import_utils import is_xformers_available - -from einops import rearrange -import pdb -import random -import math - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -@dataclass -class TransformerMV2DModelOutput(BaseOutput): - """ - The output of [`Transformer2DModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - """ - - sample: torch.FloatTensor - - -class TransformerMV2DModel(ModelMixin, ConfigMixin): - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - num_views: int = 1, - cd_attention_mid: bool=False, - cd_attention_last: bool=False, - multiview_attention: bool=True, - sparse_mv_attention: bool = True, # not used - mvcd_attention: bool=False, - use_dino: bool=False - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) - else: - self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width - ) - elif self.is_input_patches: - assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size - - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicMVTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continuous projections - if use_linear_projection: - self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) - else: - self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - dino_feature: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input `hidden_states`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 2: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - dino_feature=dino_feature, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) - - if not return_dict: - return (output,) - - return TransformerMV2DModelOutput(sample=output) - - -@maybe_allow_in_graph -class BasicMVTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - num_views: int = 1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - mvcd_attention: bool = False, - rowwise_attention: bool = True, - use_dino: bool = False - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - self.multiview_attention = multiview_attention - self.mvcd_attention = mvcd_attention - self.cd_attention_mid = cd_attention_mid - self.rowwise_attention = multiview_attention and rowwise_attention - - if mvcd_attention and (not cd_attention_mid): - # add cross domain attn to self attn - self.attn1 = CustomJointAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=JointAttnProcessor() - ) - else: - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention - ) - # 1.1 rowwise multiview attention - if self.rowwise_attention: - # print('INFO: using self+row_wise mv attention...') - self.norm_mv = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - self.attn_mv = CustomAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - processor=MVAttnProcessor() - ) - nn.init.zeros_(self.attn_mv.to_out[0].weight.data) - else: - self.norm_mv = None - self.attn_mv = None - - # # 1.2 rowwise cross-domain attn - # if mvcd_attention: - # self.attn_joint = CustomJointAttention( - # query_dim=dim, - # heads=num_attention_heads, - # dim_head=attention_head_dim, - # dropout=dropout, - # bias=attention_bias, - # cross_attention_dim=cross_attention_dim if only_cross_attention else None, - # upcast_attention=upcast_attention, - # processor=JointAttnProcessor() - # ) - # nn.init.zeros_(self.attn_joint.to_out[0].weight.data) - # self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - # else: - # self.attn_joint = None - # self.norm_joint = None - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - self.num_views = num_views - - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - dino_feature: Optional[torch.FloatTensor] = None - ): - assert attention_mask is None # not supported yet - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - # multiview_attention=self.multiview_attention, - # mvcd_attention=self.mvcd_attention, - **cross_attention_kwargs, - ) - - - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # import pdb;pdb.set_trace() - # 1.1 row wise multiview attention - if self.rowwise_attention: - norm_hidden_states = ( - self.norm_mv(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_mv(hidden_states) - ) - attn_output = self.attn_mv( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - num_views=self.num_views, - multiview_attention=self.multiview_attention, - cd_attention_mid=self.cd_attention_mid, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states - - -class CustomAttention(Attention): - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, *args, **kwargs - ): - processor = XFormersMVAttnProcessor() - self.set_processor(processor) - # print("using xformers attention processor") - - -class CustomJointAttention(Attention): - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, *args, **kwargs - ): - processor = XFormersJointAttnProcessor() - self.set_processor(processor) - # print("using xformers attention processor") - -class MVAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_views=1, - cd_attention_mid=False - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - height = int(math.sqrt(sequence_length)) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - # print('query', query.shape, 'key', key.shape, 'value', value.shape) - #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) - # pdb.set_trace() - # multi-view self-attention - def transpose(tensor): - tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height) - tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c - tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c - tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height) - return tensor - - if cd_attention_mid: - key = transpose(key) - value = transpose(value) - query = transpose(query) - else: - key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - if cd_attention_mid: - hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height) - hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c - hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c - hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height) - else: - hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class XFormersMVAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_views=1, - multiview_attention=True, - cd_attention_mid=False - ): - # print(num_views) - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - height = int(math.sqrt(sequence_length)) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # from yuancheng; here attention_mask is None - if attention_mask is not None: - # expand our mask's singleton query_tokens dimension: - # [batch*heads, 1, key_tokens] -> - # [batch*heads, query_tokens, key_tokens] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_tokens, key_tokens] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - _, query_tokens, _ = hidden_states.shape - attention_mask = attention_mask.expand(-1, query_tokens, -1) - - if attn.group_norm is not None: - print('Warning: using group norm, pay attention to use it in row-wise attention') - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key_raw = attn.to_k(encoder_hidden_states) - value_raw = attn.to_v(encoder_hidden_states) - - # print('query', query.shape, 'key', key.shape, 'value', value.shape) - # pdb.set_trace() - def transpose(tensor): - tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height) - tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c - tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c - tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height) - return tensor - # print(mvcd_attention) - # import pdb;pdb.set_trace() - if cd_attention_mid: - key = transpose(key_raw) - value = transpose(value_raw) - query = transpose(query) - else: - key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) - query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) - - - query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64]) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if cd_attention_mid: - hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height) - hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c - hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c - hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height) - else: - hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class XFormersJointAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_tasks=2 - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - # from yuancheng; here attention_mask is None - if attention_mask is not None: - # expand our mask's singleton query_tokens dimension: - # [batch*heads, 1, key_tokens] -> - # [batch*heads, query_tokens, key_tokens] - # so that it can be added as a bias onto the attention scores that xformers computes: - # [batch*heads, query_tokens, key_tokens] - # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. - _, query_tokens, _ = hidden_states.shape - attention_mask = attention_mask.expand(-1, query_tokens, -1) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - assert num_tasks == 2 # only support two tasks now - - def transpose(tensor): - tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c - tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c - return tensor - key = transpose(key) - value = transpose(value) - query = transpose(query) - # from icecream import ic - # ic(key.shape, value.shape, query.shape) - # import pdb;pdb.set_trace() - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2) - hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class JointAttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - num_tasks=2 - ): - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - assert num_tasks == 2 # only support two tasks now - - def transpose(tensor): - tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c - tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c - return tensor - key = transpose(key) - value = transpose(value) - query = transpose(query) - - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = torch.cat([hidden_states[:, 0], hidden_states[:, 1]], dim=0) # 2bv hw c - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states \ No newline at end of file diff --git a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_blocks.py b/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_blocks.py deleted file mode 100644 index 14face576d731c4d17c864e6fd2b9d156a63a8cf..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_blocks.py +++ /dev/null @@ -1,970 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Dict, Optional, Tuple - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -from diffusers.utils import is_torch_version, logging -from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 -from diffusers.models.dual_transformer_2d import DualTransformer2DModel -from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D - -from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D -from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - transformer_layers_per_block=1, - num_attention_heads=None, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - resnet_skip_time_act=False, - resnet_out_scale_factor=1.0, - cross_attention_norm=None, - attention_head_dim=None, - downsample_type=None, - num_views=1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - sparse_mv_attention: bool = False, - selfattn_block: str = "custom", - mvcd_attention: bool=False, - use_dino: bool = False -): - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warn( - f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads - - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "ResnetDownsampleBlock2D": - return ResnetDownsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - ) - elif down_block_type == "AttnDownBlock2D": - if add_downsample is False: - downsample_type = None - else: - downsample_type = downsample_type or "conv" # default to 'conv' - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - downsample_type=downsample_type, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") - return CrossAttnDownBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - # custom MV2D attention block - elif down_block_type == "CrossAttnDownBlockMV2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D") - return CrossAttnDownBlockMV2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - selfattn_block=selfattn_block, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - elif down_block_type == "SimpleCrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") - return SimpleCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - only_cross_attention=only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownEncoderBlock2D": - return AttnDownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "KDownBlock2D": - return KDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif down_block_type == "KCrossAttnDownBlock2D": - return KCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - add_self_attention=True if not add_downsample else False, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - transformer_layers_per_block=1, - num_attention_heads=None, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - resnet_skip_time_act=False, - resnet_out_scale_factor=1.0, - cross_attention_norm=None, - attention_head_dim=None, - upsample_type=None, - num_views=1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - sparse_mv_attention: bool = False, - selfattn_block: str = "custom", - mvcd_attention: bool=False, - use_dino: bool = False -): - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warn( - f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads - - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "ResnetUpsampleBlock2D": - return ResnetUpsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") - return CrossAttnUpBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - # custom MV2D attention block - elif up_block_type == "CrossAttnUpBlockMV2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D") - return CrossAttnUpBlockMV2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - selfattn_block=selfattn_block, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - elif up_block_type == "SimpleCrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") - return SimpleCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - output_scale_factor=resnet_out_scale_factor, - only_cross_attention=only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif up_block_type == "AttnUpBlock2D": - if add_upsample is False: - upsample_type = None - else: - upsample_type = upsample_type or "conv" # default to 'conv' - - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - upsample_type=upsample_type, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - temb_channels=temb_channels, - ) - elif up_block_type == "AttnUpDecoderBlock2D": - return AttnUpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attention_head_dim=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - temb_channels=temb_channels, - ) - elif up_block_type == "KUpBlock2D": - return KUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "KCrossAttnUpBlock2D": - return KCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attention_head_dim=attention_head_dim, - ) - - raise ValueError(f"{up_block_type} does not exist.") - - -class UNetMidBlockMV2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - num_views: int = 1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - sparse_mv_attention: bool = False, - selfattn_block: str = "custom", - mvcd_attention: bool=False, - use_dino: bool = False - ): - super().__init__() - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - if selfattn_block == "custom": - from .transformer_mv2d import TransformerMV2DModel - elif selfattn_block == "rowwise": - from .transformer_mv2d_rowwise import TransformerMV2DModel - elif selfattn_block == "self_rowwise": - from .transformer_mv2d_self_rowwise import TransformerMV2DModel - else: - raise NotImplementedError - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - TransformerMV2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - ) - else: - raise NotImplementedError - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - dino_feature: Optional[torch.FloatTensor] = None - ) -> torch.FloatTensor: - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - dino_feature=dino_feature, - return_dict=False, - )[0] - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class CrossAttnUpBlockMV2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - num_views: int = 1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - sparse_mv_attention: bool = False, - selfattn_block: str = "custom", - mvcd_attention: bool=False, - use_dino: bool = False - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - if selfattn_block == "custom": - from .transformer_mv2d import TransformerMV2DModel - elif selfattn_block == "rowwise": - from .transformer_mv2d_rowwise import TransformerMV2DModel - elif selfattn_block == "self_rowwise": - from .transformer_mv2d_self_rowwise import TransformerMV2DModel - else: - raise NotImplementedError - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - TransformerMV2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - ) - else: - raise NotImplementedError - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - dino_feature: Optional[torch.FloatTensor] = None - ): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - dino_feature, - None, # timestep - None, # class_labels - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - **ckpt_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - dino_feature=dino_feature, - return_dict=False, - )[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class CrossAttnDownBlockMV2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - num_views: int = 1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - sparse_mv_attention: bool = False, - selfattn_block: str = "custom", - mvcd_attention: bool=False, - use_dino: bool = False - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - if selfattn_block == "custom": - from .transformer_mv2d import TransformerMV2DModel - elif selfattn_block == "rowwise": - from .transformer_mv2d_rowwise import TransformerMV2DModel - elif selfattn_block == "self_rowwise": - from .transformer_mv2d_self_rowwise import TransformerMV2DModel - else: - raise NotImplementedError - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - TransformerMV2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - ) - else: - raise NotImplementedError - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - dino_feature: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - additional_residuals=None, - ): - output_states = () - - blocks = list(zip(self.resnets, self.attentions)) - - for i, (resnet, attn) in enumerate(blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - dino_feature, - None, # timestep - None, # class_labels - cross_attention_kwargs, - attention_mask, - encoder_attention_mask, - **ckpt_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - dino_feature=dino_feature, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - - # apply additional residuals to the output of the last pair of resnet and attention blocks - if i == len(blocks) - 1 and additional_residuals is not None: - hidden_states = hidden_states + additional_residuals - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - diff --git a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_condition.py b/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_condition.py deleted file mode 100644 index f49b01322b7b7dae9a02258943730b4346fc1792..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_condition.py +++ /dev/null @@ -1,1686 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -import os - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import UNet2DConditionLoadersMixin -from diffusers.utils import BaseOutput, logging -from diffusers.models.activations import get_activation -from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor -from diffusers.models.embeddings import ( - GaussianFourierProjection, - ImageHintTimeEmbedding, - ImageProjection, - ImageTimeEmbedding, - TextImageProjection, - TextImageTimeEmbedding, - TextTimeEmbedding, - TimestepEmbedding, - Timesteps, -) -from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model -from diffusers.models.unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - UpBlock2D, -) -from diffusers.utils import ( - CONFIG_NAME, - FLAX_WEIGHTS_NAME, - SAFETENSORS_WEIGHTS_NAME, - WEIGHTS_NAME, - _add_variant, - _get_model_file, - deprecate, - is_torch_version, - logging, -) -from diffusers.utils.import_utils import is_accelerate_available -from diffusers.utils.hub_utils import HF_HUB_OFFLINE -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE - -from diffusers import __version__ -from .unet_mv2d_blocks import ( - CrossAttnDownBlockMV2D, - CrossAttnUpBlockMV2D, - UNetMidBlockMV2DCrossAttn, - get_down_block, - get_up_block, -) -from einops import rearrange, repeat - -from diffusers import __version__ -from mvdiffusion.models.unet_mv2d_blocks import ( - CrossAttnDownBlockMV2D, - CrossAttnUpBlockMV2D, - UNetMidBlockMV2DCrossAttn, - get_down_block, - get_up_block, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class UNetMV2DConditionOutput(BaseOutput): - """ - The output of [`UNet2DConditionModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor = None - - -class ResidualBlock(nn.Module): - def __init__(self, dim): - super(ResidualBlock, self).__init__() - self.linear1 = nn.Linear(dim, dim) - self.activation = nn.SiLU() - self.linear2 = nn.Linear(dim, dim) - - def forward(self, x): - identity = x - out = self.linear1(x) - out = self.activation(out) - out = self.linear2(out) - out += identity - out = self.activation(out) - return out - -class ResidualLiner(nn.Module): - def __init__(self, in_features, out_features, dim, act=None, num_block=1): - super(ResidualLiner, self).__init__() - self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU()) - - blocks = nn.ModuleList() - for _ in range(num_block): - blocks.append(ResidualBlock(dim)) - self.blocks = blocks - - self.linear_out = nn.Linear(dim, out_features) - self.act = act - - def forward(self, x): - out = self.linear_in(x) - for block in self.blocks: - out = block(out) - out = self.linear_out(out) - if self.act is not None: - out = self.act(out) - return out - -class BasicConvBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride=1): - super(BasicConvBlock, self).__init__() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) - self.act = nn.SiLU() - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) - self.norm2 = nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) - self.downsample = nn.Sequential() - if stride != 1 or in_channels != out_channels: - self.downsample = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), - nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) - ) - - def forward(self, x): - identity = x - out = self.conv1(x) - out = self.norm1(out) - out = self.act(out) - out = self.conv2(out) - out = self.norm2(out) - out += self.downsample(identity) - out = self.act(out) - return out - -class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): - r""" - A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample - shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or - `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, normalization and activation layers is skipped in post-processing. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - encoder_hid_dim (`int`, *optional*, defaults to None): - If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` - dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to `None`): - If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text - embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - num_attention_heads (`int`, *optional*): - The number of attention heads. If not defined, defaults to `attention_head_dim` - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text". "text" will use the `TextTimeEmbedding` layer. - addition_time_embed_dim: (`int`, *optional*, defaults to `None`): - Dimension for the timestep embeddings. - num_class_embeds (`int`, *optional*, defaults to `None`): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, defaults to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, defaults to `None`): - An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, defaults to `None`): - Optional activation function to use only once on the time embeddings before they are passed to the rest of - the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str`, *optional*, defaults to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, defaults to `None`): - The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - `class_embed_type="projection"`. Required when `class_embed_type="projection"`. - class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time - embeddings with the class embeddings. - mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): - Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If - `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the - `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` - otherwise. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlockMV2D", - "CrossAttnDownBlockMV2D", - "CrossAttnDownBlockMV2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: Union[int, Tuple[int]] = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int]]] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: int = 1.0, - time_embedding_type: str = "positional", - time_embedding_dim: Optional[int] = None, - time_embedding_act_fn: Optional[str] = None, - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, - projection_camera_embeddings_input_dim: Optional[int] = None, - class_embeddings_concat: bool = False, - mid_block_only_cross_attention: Optional[bool] = None, - cross_attention_norm: Optional[str] = None, - addition_embed_type_num_heads=64, - num_views: int = 1, - cd_attention_last: bool = False, - cd_attention_mid: bool = False, - multiview_attention: bool = True, - sparse_mv_attention: bool = False, - selfattn_block: str = "custom", - mvcd_attention: bool = False, - regress_elevation: bool = False, - regress_focal_length: bool = False, - num_regress_blocks: int = 4, - use_dino: bool = False, - addition_downsample: bool = False, - addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280), - ): - super().__init__() - - self.sample_size = sample_size - self.num_views = num_views - self.mvcd_attention = mvcd_attention - if num_attention_heads is not None: - raise ValueError( - "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." - ) - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." - ) - - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - if time_embedding_type == "fourier": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." - ) - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) - - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - self.encoder_hid_proj = ImageProjection( - image_embed_dim=encoder_hid_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type is not None: - raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." - ) - else: - self.encoder_hid_proj = None - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif class_embed_type == "simple_projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif addition_embed_type == "image": - # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - - if time_embedding_act_fn is None: - self.time_embed_act = None - else: - self.time_embed_act = get_activation(time_embedding_act_fn) - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = only_cross_attention - - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = False - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - if class_embeddings_concat: - # The time embeddings are concatenated with the class embeddings. The dimension of the - # time embeddings passed to the down, middle, and up blocks is twice the dimension of the - # regular time embeddings - blocks_time_embed_dim = time_embed_dim * 2 - else: - blocks_time_embed_dim = time_embed_dim - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - selfattn_block=selfattn_block, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - # custom MV2D attention block - elif mid_block_type == "UNetMidBlockMV2DCrossAttn": - self.mid_block = UNetMidBlockMV2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - selfattn_block=selfattn_block, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim[-1], - attention_head_dim=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - only_cross_attention=mid_block_only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - self.addition_downsample = addition_downsample - if self.addition_downsample: - inc = block_out_channels[-1] - self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.conv_block = nn.ModuleList() - self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1)) - for dim_ in addition_channels[1:-1]: - self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1)) - self.conv_block.append(BasicConvBlock(dim_, inc)) - self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False) - nn.init.zeros_(self.addition_conv_out.weight.data) - self.addition_act_out = nn.SiLU() - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - - self.regress_elevation = regress_elevation - self.regress_focal_length = regress_focal_length - if regress_elevation or regress_focal_length: - self.pool = nn.AdaptiveAvgPool2d((1, 1)) - self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim) - - regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels - - if regress_elevation: - self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks) - if regress_focal_length: - self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks) - ''' - self.regress_elevation = regress_elevation - self.regress_focal_length = regress_focal_length - if regress_elevation and (not regress_focal_length): - print("Regressing elevation") - cam_dim = 1 - elif regress_focal_length and (not regress_elevation): - print("Regressing focal length") - cam_dim = 6 - elif regress_elevation and regress_focal_length: - print("Regressing both elevation and focal length") - cam_dim = 7 - else: - cam_dim = 0 - assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim" - if regress_elevation or regress_focal_length: - self.elevation_regressor = nn.ModuleList([ - nn.Linear(block_out_channels[-1], 1280), - nn.SiLU(), - nn.Linear(1280, 1280), - nn.SiLU(), - nn.Linear(1280, cam_dim) - ]) - self.pool = nn.AdaptiveAvgPool2d((1, 1)) - self.focal_act = nn.Softmax(dim=-1) - self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim) - ''' - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_num_attention_heads = list(reversed(num_attention_heads)) - reversed_layers_per_block = list(reversed(layers_per_block)) - reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) - only_cross_attention = list(reversed(only_cross_attention)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=blocks_time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - num_views=num_views, - cd_attention_last=cd_attention_last, - cd_attention_mid=cd_attention_mid, - multiview_attention=multiview_attention, - sparse_mv_attention=sparse_mv_attention, - selfattn_block=selfattn_block, - mvcd_attention=mvcd_attention, - use_dino=use_dino - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - - self.conv_act = get_activation(act_fn) - - else: - self.conv_norm_out = None - self.conv_act = None - - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) - - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - self.set_attn_processor(AttnProcessor()) - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - dino_feature: Optional[torch.Tensor] = None, - return_dict: bool = True, - vis_max_min: bool = False, - ) -> Union[UNetMV2DConditionOutput, Tuple]: - r""" - The [`UNet2DConditionModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, channel, height, width)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.FloatTensor`): - The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. - encoder_attention_mask (`torch.Tensor`): - A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If - `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, - which adds large negative values to the attention scores corresponding to "discard" tokens. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. - """ - record_max_min = {} - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # there might be better ways to encapsulate this. - class_labels = class_labels.to(dtype=sample.dtype) - - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) - if self.config.class_embeddings_concat: - emb = torch.cat([emb, class_emb], dim=-1) - else: - emb = emb + class_emb - - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - elif self.config.addition_embed_type == "text_image": - # Kandinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - - image_embs = added_cond_kwargs.get("image_embeds") - text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) - aug_emb = self.add_embedding(text_embs, image_embs) - elif self.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - elif self.config.addition_embed_type == "image": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - aug_emb = self.add_embedding(image_embs) - elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - hint = added_cond_kwargs.get("hint") - aug_emb, hint = self.add_embedding(image_embs, hint) - sample = torch.cat([sample, hint], dim=1) - - emb = emb + aug_emb if aug_emb is not None else emb - emb_pre_act = emb - if self.time_embed_act is not None: - emb = self.time_embed_act(emb) - - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": - # Kadinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(image_embeds) - # 2. pre-process - sample = self.conv_in(sample) - # 3. down - - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None - - down_block_res_samples = (sample,) - for i, downsample_block in enumerate(self.down_blocks): - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - # For t2i-adapter CrossAttnDownBlock2D - additional_residuals = {} - if is_adapter and len(down_block_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) - - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - dino_feature=dino_feature, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - **additional_residuals, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - if is_adapter and len(down_block_additional_residuals) > 0: - sample += down_block_additional_residuals.pop(0) - - down_block_res_samples += res_samples - - if is_controlnet: - new_down_block_res_samples = () - - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - if self.addition_downsample: - global_sample = sample - global_sample = self.downsample(global_sample) - for layer in self.conv_block: - global_sample = layer(global_sample) - global_sample = self.addition_act_out(self.addition_conv_out(global_sample)) - global_sample = self.upsample(global_sample) - # 4. mid - if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - dino_feature=dino_feature, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - # 4.1 regress elevation and focal length - # # predict elevation -> embed -> projection -> add to time emb - if self.regress_elevation or self.regress_focal_length: - pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C) - if self.mvcd_attention: - pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0) - pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C) - pose_pred = [] - if self.regress_elevation: - ele_pred = self.elevation_regressor(pool_embeds) - ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views) - ele_pred = torch.mean(ele_pred, dim=1) - pose_pred.append(ele_pred) # b, c - - if self.regress_focal_length: - focal_pred = self.focal_regressor(pool_embeds) - focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views) - focal_pred = torch.mean(focal_pred, dim=1) - pose_pred.append(focal_pred) - pose_pred = torch.cat(pose_pred, dim=-1) - # 'e_de_da_sincos', (B, 2) - pose_embeds = torch.cat([ - torch.sin(pose_pred), - torch.cos(pose_pred) - ], dim=-1) - pose_embeds = self.camera_embedding(pose_embeds) - pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0) - if self.mvcd_attention: - pose_embeds = torch.cat([pose_embeds,] * 2, dim=0) - - emb = pose_embeds + emb_pre_act - if self.time_embed_act is not None: - emb = self.time_embed_act(emb) - - if is_controlnet: - sample = sample + mid_block_additional_residual - - if self.addition_downsample: - sample = sample + global_sample - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - dino_feature=dino_feature, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - if torch.isnan(sample).any() or torch.isinf(sample).any(): - print("NAN in sample, stop training.") - exit() - # 6. post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - if not return_dict: - return (sample, pose_pred) - if self.regress_elevation or self.regress_focal_length: - return UNetMV2DConditionOutput(sample=sample), pose_pred - else: - return UNetMV2DConditionOutput(sample=sample) - - - @classmethod - def from_pretrained_2d( - cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - camera_embedding_type: str, num_views: int, sample_size: int, - zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, - projection_camera_embeddings_input_dim: int=2, - cd_attention_last: bool = False, num_regress_blocks: int = 4, - cd_attention_mid: bool = False, multiview_attention: bool = True, - sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False, - in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False, - init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False, - **kwargs - ): - r""" - Instantiate a pretrained PyTorch model from a pretrained model configuration. - - The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To - train the model, set it back in training mode with `model.train()`. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`~ModelMixin.save_pretrained`]. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info (`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - from_flax (`bool`, *optional*, defaults to `False`): - Load the model weights from a Flax checkpoint save file. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - mirror (`str`, *optional*): - Mirror source to resolve accessibility issues if you're downloading a model in China. We do not - guarantee the timeliness or safety of the source, and you should refer to the mirror site for more - information. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be defined for each - parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the - same device. - - Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - max_memory (`Dict`, *optional*): - A dictionary device identifier for the maximum memory. Will default to the maximum memory available for - each GPU and the available CPU RAM if unset. - offload_folder (`str` or `os.PathLike`, *optional*): - The path to offload weights if `device_map` contains the value `"disk"`. - offload_state_dict (`bool`, *optional*): - If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if - the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` - when there is some disk offload. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. - variant (`str`, *optional*): - Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when - loading `from_flax`. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the - `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` - weights. If set to `False`, `safetensors` weights are not loaded. - - - - To use private or [gated models](https://huggingface.co./docs/hub/models-gated#gated-models), log-in with - `huggingface-cli login`. You can also activate the special - ["offline-mode"](https://huggingface.co./diffusers/installation.html#offline-mode) to use this method in a - firewalled environment. - - - - Example: - - ```py - from diffusers import UNet2DConditionModel - - unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") - ``` - - If you get the error message below, you need to finetune the weights for your downstream task: - - ```bash - Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated - You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. - ``` - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) - force_download = kwargs.pop("force_download", False) - from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - subfolder = kwargs.pop("subfolder", None) - device_map = kwargs.pop("device_map", None) - max_memory = kwargs.pop("max_memory", None) - offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", False) - variant = kwargs.pop("variant", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - if use_safetensors: - raise ValueError( - "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" - ) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - if device_map is not None and not is_accelerate_available(): - raise NotImplementedError( - "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" - " `device_map=None`. You can install accelerate with `pip install accelerate`." - ) - - # Check if we can handle device_map and dispatching the weights - if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `device_map=None`." - ) - - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - - user_agent = { - "diffusers": __version__, - "file_type": "model", - "framework": "pytorch", - } - - # load config - config, unused_kwargs, commit_hash = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - return_commit_hash=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - user_agent=user_agent, - **kwargs, - ) - - # modify config - config["_class_name"] = cls.__name__ - config['in_channels'] = in_channels - config['out_channels'] = out_channels - config['sample_size'] = sample_size # training resolution - config['num_views'] = num_views - config['cd_attention_last'] = cd_attention_last - config['cd_attention_mid'] = cd_attention_mid - config['multiview_attention'] = multiview_attention - config['sparse_mv_attention'] = sparse_mv_attention - config['selfattn_block'] = selfattn_block - config['mvcd_attention'] = mvcd_attention - config["down_block_types"] = [ - "CrossAttnDownBlockMV2D", - "CrossAttnDownBlockMV2D", - "CrossAttnDownBlockMV2D", - "DownBlock2D" - ] - config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" - config["up_block_types"] = [ - "UpBlock2D", - "CrossAttnUpBlockMV2D", - "CrossAttnUpBlockMV2D", - "CrossAttnUpBlockMV2D" - ] - - - config['regress_elevation'] = regress_elevation # true - config['regress_focal_length'] = regress_focal_length # true - config['projection_camera_embeddings_input_dim'] = projection_camera_embeddings_input_dim # 2 for elevation and 10 for focal_length - config['use_dino'] = use_dino - config['num_regress_blocks'] = num_regress_blocks - config['addition_downsample'] = addition_downsample - # load model - model_file = None - if from_flax: - raise NotImplementedError - else: - if use_safetensors: - try: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - except IOError as e: - if not allow_pickle: - raise e - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - - model = cls.from_config(config, **unused_kwargs) - import copy - state_dict_pretrain = load_state_dict(model_file, variant=variant) - state_dict = copy.deepcopy(state_dict_pretrain) - - if init_mvattn_with_selfattn: - for key in state_dict_pretrain: - if 'attn1' in key: - key_mv = key.replace('attn1', 'attn_mv') - state_dict[key_mv] = state_dict_pretrain[key] - if 'to_out.0.weight' in key: - nn.init.zeros_(state_dict[key_mv].data) - if 'transformer_blocks' in key and 'norm1' in key: # in case that initialize the norm layer in resnet block - key_mv = key.replace('norm1', 'norm_mv') - state_dict[key_mv] = state_dict_pretrain[key] - # del state_dict_pretrain - - model._convert_deprecated_attention_blocks(state_dict) - - conv_in_weight = state_dict['conv_in.weight'] - conv_out_weight = state_dict['conv_out.weight'] - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=True, - ) - if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): - # initialize from the original SD structure - model.conv_in.weight.data[:,:4] = conv_in_weight - - # whether to place all zero to new layers? - if zero_init_conv_in: - model.conv_in.weight.data[:,4:] = 0. - - if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): - # initialize from the original SD structure - model.conv_out.weight.data[:,:4] = conv_out_weight - if out_channels == 8: # copy for the last 4 channels - model.conv_out.weight.data[:, 4:] = conv_out_weight - - if zero_init_camera_projection: # true - params = [p for p in model.camera_embedding.parameters()] - torch.nn.init.zeros_(params[-1].data) - - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." - ) - elif torch_dtype is not None: - model = model.to(torch_dtype) - - model.register_to_config(_name_or_path=pretrained_model_name_or_path) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() - if output_loading_info: - return model, loading_info - return model - - @classmethod - def _load_pretrained_model_2d( - cls, - model, - state_dict, - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, - ): - # Retrieve missing & unexpected_keys - model_state_dict = model.state_dict() - loaded_keys = list(state_dict.keys()) - - expected_keys = list(model_state_dict.keys()) - - original_loaded_keys = loaded_keys - - missing_keys = list(set(expected_keys) - set(loaded_keys)) - unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - - # Make sure we are able to load base models as well as derived models (with heads) - model_to_load = model - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( - state_dict, - model_state_dict, - original_loaded_keys, - ignore_mismatched_sizes, - ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict) - - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if "size mismatch" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - ) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" - " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" - " identical (initializing a BertForSequenceClassification model from a" - " BertForSequenceClassification model)." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" - f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" - " without further training." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" - " able to use it for predictions and inference." - ) - - return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs - diff --git a/apps/third_party/Era3D/pipelines/pipeline_mvdiffusion_unclip.py b/apps/third_party/Era3D/pipelines/pipeline_mvdiffusion_unclip.py deleted file mode 100644 index c3b04b6c93830de654aeee68ad90a1fb0c686211..0000000000000000000000000000000000000000 --- a/apps/third_party/Era3D/pipelines/pipeline_mvdiffusion_unclip.py +++ /dev/null @@ -1,633 +0,0 @@ -import inspect -import warnings -from typing import Callable, List, Optional, Union, Dict, Any -import PIL -import torch -from packaging import version -from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel -from diffusers.utils.import_utils import is_accelerate_available -from diffusers.configuration_utils import FrozenDict -from diffusers.image_processor import VaeImageProcessor -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.embeddings import get_timestep_embedding -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import deprecate, logging -from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -import os -import torchvision.transforms.functional as TF -from einops import rearrange -logger = logging.get_logger(__name__) - -class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): - """ - Pipeline for text-guided image to image generation using stable unCLIP. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - feature_extractor ([`CLIPFeatureExtractor`]): - Feature extractor for image pre-processing before being encoded. - image_encoder ([`CLIPVisionModelWithProjection`]): - CLIP vision model for encoding images. - image_normalizer ([`StableUnCLIPImageNormalizer`]): - Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image - embeddings after the noise has been applied. - image_noising_scheduler ([`KarrasDiffusionSchedulers`]): - Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined - by `noise_level` in `StableUnCLIPPipeline.__call__`. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co./docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`KarrasDiffusionSchedulers`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - """ - # image encoding components - feature_extractor: CLIPFeatureExtractor - image_encoder: CLIPVisionModelWithProjection - # image noising components - image_normalizer: StableUnCLIPImageNormalizer - image_noising_scheduler: KarrasDiffusionSchedulers - # regular denoising components - tokenizer: CLIPTokenizer - text_encoder: CLIPTextModel - unet: UNet2DConditionModel - scheduler: KarrasDiffusionSchedulers - vae: AutoencoderKL - - def __init__( - self, - # image encoding components - feature_extractor: CLIPFeatureExtractor, - image_encoder: CLIPVisionModelWithProjection, - # image noising components - image_normalizer: StableUnCLIPImageNormalizer, - image_noising_scheduler: KarrasDiffusionSchedulers, - # regular denoising components - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - # vae - vae: AutoencoderKL, - num_views: int = 4, - ): - super().__init__() - - self.register_modules( - feature_extractor=feature_extractor, - image_encoder=image_encoder, - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - tokenizer=tokenizer, - text_encoder=text_encoder, - unet=unet, - scheduler=scheduler, - vae=vae, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.num_views: int = num_views - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's - models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only - when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list - models = [ - self.image_encoder, - self.text_encoder, - self.unet, - self.vae, - ] - for cpu_offloaded_model in models: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - if do_classifier_free_guidance: - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0) - - prompt_embeds = torch.cat([normal_prompt_embeds, normal_prompt_embeds, color_prompt_embeds, color_prompt_embeds], 0) - - return prompt_embeds - - def _encode_image( - self, - image_pil, - device, - num_images_per_prompt, - do_classifier_free_guidance, - noise_level: int=0, - generator: Optional[torch.Generator] = None - ): - dtype = next(self.image_encoder.parameters()).dtype - # ______________________________clip image embedding______________________________ - image = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values - image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - - image_embeds = self.noise_image_embeddings( - image_embeds=image_embeds, - noise_level=noise_level, - generator=generator, - ) - # duplicate image embeddings for each generation per prompt, using mps friendly method - # image_embeds = image_embeds.unsqueeze(1) - # note: the condition input is same - image_embeds = image_embeds.repeat(num_images_per_prompt, 1) - - if do_classifier_free_guidance: - normal_image_embeds, color_image_embeds = torch.chunk(image_embeds, 2, dim=0) - negative_prompt_embeds = torch.zeros_like(normal_image_embeds) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_embeds = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0) - - # _____________________________vae input latents__________________________________________________ - image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(dtype=self.vae.dtype, device=device) - image_pt = image_pt * 2.0 - 1.0 - image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor - # Note: repeat differently from official pipelines - image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1) - - if do_classifier_free_guidance: - normal_image_latents, color_image_latents = torch.chunk(image_latents, 2, dim=0) - image_latents = torch.cat([torch.zeros_like(normal_image_latents), normal_image_latents, - torch.zeros_like(color_image_latents), color_image_latents], 0) - - return image_embeds, image_latents - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - height, - width, - callback_steps, - noise_level, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - - if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: - raise ValueError( - f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings - def noise_image_embeddings( - self, - image_embeds: torch.Tensor, - noise_level: int, - noise: Optional[torch.FloatTensor] = None, - generator: Optional[torch.Generator] = None, - ): - """ - Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher - `noise_level` increases the variance in the final un-noised images. - - The noise is applied in two ways - 1. A noise schedule is applied directly to the embeddings - 2. A vector of sinusoidal time embeddings are appended to the output. - - In both cases, the amount of noise is controlled by the same `noise_level`. - - The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. - """ - if noise is None: - noise = randn_tensor( - image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype - ) - - noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) - - image_embeds = self.image_normalizer.scale(image_embeds) - - image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) - - image_embeds = self.image_normalizer.unscale(image_embeds) - - noise_level = get_timestep_embedding( - timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 - ) - - # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, - # but we might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - noise_level = noise_level.to(image_embeds.dtype) - - image_embeds = torch.cat((image_embeds, noise_level), 1) - - return image_embeds - - @torch.no_grad() - # @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - prompt_embeds: torch.FloatTensor = None, - dino_feature: torch.FloatTensor = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 20, - guidance_scale: float = 10, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - noise_level: int = 0, - image_embeds: Optional[torch.FloatTensor] = None, - return_elevation_focal: Optional[bool] = False, - gt_img_in: Optional[torch.FloatTensor] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which - the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the - latents in the denoising process such as in the standard stable diffusion text guided image variation - process. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 20): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 10.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - noise_level (`int`, *optional*, defaults to `0`): - The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in - the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. - image_embeds (`torch.FloatTensor`, *optional*): - Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in - the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as - `latents`. - - Examples: - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt=prompt, - image=image, - height=height, - width=width, - callback_steps=callback_steps, - noise_level=noise_level - ) - - # 2. Define call parameters - if isinstance(image, list): - batch_size = len(image) - elif isinstance(image, torch.Tensor): - batch_size = image.shape[0] - assert batch_size >= self.num_views and batch_size % self.num_views == 0 - elif isinstance(image, PIL.Image.Image): - image = [image]*self.num_views*2 - batch_size = self.num_views*2 - - if isinstance(prompt, str): - prompt = [prompt] * self.num_views * 2 - - device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale != 1.0 - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - prompt_embeds = self._encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - ) - - - # 4. Encoder input image - if isinstance(image, list): - image_pil = image - elif isinstance(image, torch.Tensor): - image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])] - noise_level = torch.tensor([noise_level], device=device) - image_embeds, image_latents = self._encode_image( - image_pil=image_pil, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - noise_level=noise_level, - generator=generator, - ) - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.out_channels - if gt_img_in is not None: - latents = gt_img_in * self.scheduler.init_noise_sigma - else: - latents = self.prepare_latents( - batch_size=batch_size, - num_channels_latents=num_channels_latents, - height=height, - width=width, - dtype=prompt_embeds.dtype, - device=device, - generator=generator, - latents=latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - eles, focals = [], [] - # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - if do_classifier_free_guidance: - normal_latents, color_latents = torch.chunk(latents, 2, dim=0) - latent_model_input = torch.cat([normal_latents, normal_latents, color_latents, color_latents], 0) - else: - latent_model_input = latents - latent_model_input = torch.cat([ - latent_model_input, image_latents - ], dim=1) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - unet_out = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - dino_feature=dino_feature, - class_labels=image_embeds, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False) - - noise_pred = unet_out[0] - if return_elevation_focal: - uncond_pose, pose = torch.chunk(unet_out[1], 2, 0) - pose = uncond_pose + guidance_scale * (pose - uncond_pose) - ele = pose[:, 0].detach().cpu().numpy() # b - eles.append(ele) - focal = pose[:, 1].detach().cpu().numpy() - focals.append(focal) - - # perform guidance - if do_classifier_free_guidance: - normal_noise_pred_uncond, normal_noise_pred_text, color_noise_pred_uncond, color_noise_pred_text = torch.chunk(noise_pred, 4, dim=0) - - noise_pred_uncond, noise_pred_text = torch.cat([normal_noise_pred_uncond, color_noise_pred_uncond], 0), torch.cat([normal_noise_pred_text, color_noise_pred_text], 0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing - if not output_type == "latent": - if num_channels_latents == 8: - latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0) - with torch.no_grad(): - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - else: - image = latents - - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload last model to CPU - # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - # self.final_offload_hook.offload() - if not return_dict: - return (image, ) - if return_elevation_focal: - return ImagePipelineOutput(images=image), eles, focals - else: - return ImagePipelineOutput(images=image) \ No newline at end of file diff --git a/apps/third_party/InstantMeshes b/apps/third_party/InstantMeshes deleted file mode 100644 index 2350cb16946fb0e157b7195fd0688a4c0a2e4596..0000000000000000000000000000000000000000 --- a/apps/third_party/InstantMeshes +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a1946606e15d452b3d16ff3ec0161ff12053b6e0c8568d3eab6cd7e97486448d -size 2759416 diff --git a/apps/third_party/LGM/.gitignore b/apps/third_party/LGM/.gitignore deleted file mode 100644 index 6d505a0dde1a5e08118f2ba9d49dc32561b85ef8..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -*.pt -*.yaml -**/__pycache__ -*.pyc - -weights* -models -sd-v2* \ No newline at end of file diff --git a/apps/third_party/LGM/README.md b/apps/third_party/LGM/README.md deleted file mode 100644 index 6450c21d2cc4c083350f99fa8f8b765efb72491a..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/README.md +++ /dev/null @@ -1,71 +0,0 @@ -# MVDream-diffusers - -A **unified** diffusers implementation of [MVDream](https://github.com/bytedance/MVDream) and [ImageDream](https://github.com/bytedance/ImageDream). - -We provide converted `fp16` weights on huggingface: -* [MVDream](https://huggingface.co./ashawkey/mvdream-sd2.1-diffusers) -* [ImageDream](https://huggingface.co./ashawkey/imagedream-ipmv-diffusers) - - -### Install -```bash -# dependency -pip install -r requirements.txt - -# xformers is required! please refer to https://github.com/facebookresearch/xformers -pip install ninja -pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers -``` - -### Usage - -```bash -python run_mvdream.py "a cute owl" -python run_imagedream.py data/anya_rgba.png -``` - -### Convert weights - -MVDream: -```bash -# download original ckpt (we only support the SD 2.1 version) -mkdir models -cd models -wget https://huggingface.co./MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt -wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml -cd .. - -# convert -python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view.pt --dump_path ./weights_mvdream --original_config_file models/sd-v2-base.yaml --half --to_safetensors --test -``` - -ImageDream: -```bash -# download original ckpt (we only support the pixel-controller version) -cd models -wget https://huggingface.co./Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv.pt -wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv.yaml -cd .. - -# convert -python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv.pt --dump_path ./weights_imagedream --original_config_file models/sd_v2_base_ipmv.yaml --half --to_safetensors --test -``` - -### Acknowledgement - -* The original papers: - ```bibtex - @article{shi2023MVDream, - author = {Shi, Yichun and Wang, Peng and Ye, Jianglong and Mai, Long and Li, Kejie and Yang, Xiao}, - title = {MVDream: Multi-view Diffusion for 3D Generation}, - journal = {arXiv:2308.16512}, - year = {2023}, - } - @article{wang2023imagedream, - title={ImageDream: Image-Prompt Multi-view Diffusion for 3D Generation}, - author={Wang, Peng and Shi, Yichun}, - journal={arXiv preprint arXiv:2312.02201}, - year={2023} - } - ``` -* This codebase is modified from [mvdream-hf](https://github.com/KokeCacao/mvdream-hf). \ No newline at end of file diff --git a/apps/third_party/LGM/__pycache__/mv_unet.cpython-310.pyc b/apps/third_party/LGM/__pycache__/mv_unet.cpython-310.pyc deleted file mode 100644 index c3ed493ca51586c8da9d3fb25e5a1e9a7f97403e..0000000000000000000000000000000000000000 Binary files a/apps/third_party/LGM/__pycache__/mv_unet.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/LGM/__pycache__/mv_unet.cpython-38.pyc b/apps/third_party/LGM/__pycache__/mv_unet.cpython-38.pyc deleted file mode 100644 index 537fda445006262ed0efe0083ace93a1c8356c9b..0000000000000000000000000000000000000000 Binary files a/apps/third_party/LGM/__pycache__/mv_unet.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-310.pyc b/apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-310.pyc deleted file mode 100644 index bf696f21c305e41ded94b91498138bbfd0f6f95d..0000000000000000000000000000000000000000 Binary files a/apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-310.pyc and /dev/null differ diff --git a/apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-38.pyc b/apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-38.pyc deleted file mode 100644 index c26aa68daf81aebf5ac8f0a20bf869140268c35c..0000000000000000000000000000000000000000 Binary files a/apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-38.pyc and /dev/null differ diff --git a/apps/third_party/LGM/convert_mvdream_to_diffusers.py b/apps/third_party/LGM/convert_mvdream_to_diffusers.py deleted file mode 100644 index 1008cb46c0a42ec6da65f98645e5a2b1f9baea9e..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/convert_mvdream_to_diffusers.py +++ /dev/null @@ -1,597 +0,0 @@ -# Modified from https://github.com/huggingface/diffusers/blob/bc691231360a4cbc7d19a58742ebb8ed0f05e027/scripts/convert_original_stable_diffusion_to_diffusers.py - -import argparse -import torch -import sys - -sys.path.insert(0, ".") - -from diffusers.models import ( - AutoencoderKL, -) -from omegaconf import OmegaConf -from diffusers.schedulers import DDIMScheduler -from diffusers.utils import logging -from typing import Any -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor - -from mv_unet import MultiViewUNetModel -from pipeline_mvdream import MVDreamPipeline -import kiui - -logger = logging.get_logger(__name__) - - -def assign_to_checkpoint( - paths, - checkpoint, - old_checkpoint, - attention_paths_to_split=None, - additional_replacements=None, - config=None, -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - Assigns the weights to the new checkpoint. - """ - assert isinstance( - paths, list - ), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - assert config is not None - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape( - (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] - ) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if ( - attention_paths_to_split is not None - and new_path in attention_paths_to_split - ): - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - is_attn_weight = "proj_attn.weight" in new_path or ( - "attentions" in new_path and "to_" in new_path - ) - shape = old_checkpoint[path["old"]].shape - if is_attn_weight and len(shape) == 3: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - elif is_attn_weight and len(shape) == 4: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def create_vae_diffusers_config(original_config, image_size): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - - - if 'imagedream' in original_config.model.target: - vae_params = original_config.model.params.vae_config.params.ddconfig - _ = original_config.model.params.vae_config.params.embed_dim - vae_key = "vae_model." - else: - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim - vae_key = "first_stage_model." - - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, - "down_block_types": tuple(down_block_types), - "up_block_types": tuple(up_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, - } - return config, vae_key - - -def convert_ldm_vae_checkpoint(checkpoint, config, vae_key): - # extract state dict for VAE - vae_state_dict = {} - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ - "encoder.conv_out.weight" - ] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ - "encoder.norm_out.weight" - ] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ - "encoder.norm_out.bias" - ] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ - "decoder.conv_out.weight" - ] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ - "decoder.norm_out.weight" - ] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ - "decoder.norm_out.bias" - ] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "encoder.down" in layer - } - ) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] - for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "decoder.up" in layer - } - ) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] - for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [ - key - for key in down_blocks[i] - if f"down.{i}" in key and f"down.{i}.downsample" not in key - ] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key - for key in up_blocks[block_id] - if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "to_q.weight") - new_item = new_item.replace("q.bias", "to_q.bias") - - new_item = new_item.replace("k.weight", "to_k.weight") - new_item = new_item.replace("k.bias", "to_k.bias") - - new_item = new_item.replace("v.weight", "to_v.weight") - new_item = new_item.replace("v.bias", "to_v.bias") - - new_item = new_item.replace("proj_out.weight", "to_out.0.weight") - new_item = new_item.replace("proj_out.bias", "to_out.0.bias") - - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_config(original_config): - return OmegaConf.to_container( - original_config.model.params.unet_config.params, resolve=True - ) - - -def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device): - checkpoint = torch.load(checkpoint_path, map_location=device) - # print(f"Checkpoint: {checkpoint.keys()}") - torch.cuda.empty_cache() - - original_config = OmegaConf.load(original_config_file) - # print(f"Original Config: {original_config}") - prediction_type = "epsilon" - image_size = 256 - num_train_timesteps = ( - getattr(original_config.model.params, "timesteps", None) or 1000 - ) - beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 - beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - scheduler.register_to_config(clip_sample=False) - - unet_config = create_unet_config(original_config) - - # remove unused configs - unet_config.pop('legacy', None) - unet_config.pop('use_linear_in_transformer', None) - unet_config.pop('use_spatial_transformer', None) - - unet_config.pop('ip_mode', None) - unet_config.pop('with_ip', None) - - unet = MultiViewUNetModel(**unet_config) - unet.register_to_config(**unet_config) - # print(f"Unet State Dict: {unet.state_dict().keys()}") - unet.load_state_dict( - { - key.replace("model.diffusion_model.", ""): value - for key, value in checkpoint.items() - if key.replace("model.diffusion_model.", "") in unet.state_dict() - } - ) - for param_name, param in unet.state_dict().items(): - set_module_tensor_to_device(unet, param_name, device=device, value=param) - - # Convert the VAE model. - vae_config, vae_key = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config, vae_key) - - if ( - "model" in original_config - and "params" in original_config.model - and "scale_factor" in original_config.model.params - ): - vae_scaling_factor = original_config.model.params.scale_factor - else: - vae_scaling_factor = 0.18215 # default SD scaling factor - - vae_config["scaling_factor"] = vae_scaling_factor - - with init_empty_weights(): - vae = AutoencoderKL(**vae_config) - - for param_name, param in converted_vae_checkpoint.items(): - set_module_tensor_to_device(vae, param_name, device=device, value=param) - - # we only supports SD 2.1 based model - tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer") - text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore - - # imagedream variant - if unet.ip_dim > 0: - feature_extractor: CLIPImageProcessor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - image_encoder: CLIPVisionModel = CLIPVisionModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - else: - feature_extractor = None - image_encoder = None - - pipe = MVDreamPipeline( - vae=vae, - unet=unet, - tokenizer=tokenizer, - text_encoder=text_encoder, - scheduler=scheduler, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - ) - - return pipe - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "--checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the checkpoint to convert.", - ) - parser.add_argument( - "--original_config_file", - default=None, - type=str, - help="The YAML config file corresponding to the original architecture.", - ) - parser.add_argument( - "--to_safetensors", - action="store_true", - help="Whether to store pipeline in safetensors format or not.", - ) - parser.add_argument( - "--half", action="store_true", help="Save weights in half precision." - ) - parser.add_argument( - "--test", - action="store_true", - help="Whether to test inference after convertion.", - ) - parser.add_argument( - "--dump_path", - default=None, - type=str, - required=True, - help="Path to the output model.", - ) - parser.add_argument( - "--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)" - ) - args = parser.parse_args() - - args.device = torch.device( - args.device - if args.device is not None - else "cuda" - if torch.cuda.is_available() - else "cpu" - ) - - pipe = convert_from_original_mvdream_ckpt( - checkpoint_path=args.checkpoint_path, - original_config_file=args.original_config_file, - device=args.device, - ) - - if args.half: - pipe.to(torch_dtype=torch.float16) - - print(f"Saving pipeline to {args.dump_path}...") - pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) - - if args.test: - try: - # mvdream - if pipe.unet.ip_dim == 0: - print(f"Testing each subcomponent of the pipeline...") - images = pipe( - prompt="Head of Hatsune Miku", - negative_prompt="painting, bad quality, flat", - output_type="pil", - guidance_scale=7.5, - num_inference_steps=50, - device=args.device, - ) - for i, image in enumerate(images): - image.save(f"test_image_{i}.png") # type: ignore - - print(f"Testing entire pipeline...") - loaded_pipe = MVDreamPipeline.from_pretrained(args.dump_path) # type: ignore - images = loaded_pipe( - prompt="Head of Hatsune Miku", - negative_prompt="painting, bad quality, flat", - output_type="pil", - guidance_scale=7.5, - num_inference_steps=50, - device=args.device, - ) - for i, image in enumerate(images): - image.save(f"test_image_{i}.png") # type: ignore - # imagedream - else: - input_image = kiui.read_image('data/anya_rgba.png', mode='float') - print(f"Testing each subcomponent of the pipeline...") - images = pipe( - image=input_image, - prompt="", - negative_prompt="", - output_type="pil", - guidance_scale=5.0, - num_inference_steps=50, - device=args.device, - ) - for i, image in enumerate(images): - image.save(f"test_image_{i}.png") # type: ignore - - print(f"Testing entire pipeline...") - loaded_pipe = MVDreamPipeline.from_pretrained(args.dump_path) # type: ignore - images = loaded_pipe( - image=input_image, - prompt="", - negative_prompt="", - output_type="pil", - guidance_scale=5.0, - num_inference_steps=50, - device=args.device, - ) - for i, image in enumerate(images): - image.save(f"test_image_{i}.png") # type: ignore - - - print("Inference test passed!") - except Exception as e: - print(f"Failed to test inference: {e}") diff --git a/apps/third_party/LGM/data/anya_rgba.png b/apps/third_party/LGM/data/anya_rgba.png deleted file mode 100644 index 089499e16e410207c890b45bc865627352df967d..0000000000000000000000000000000000000000 Binary files a/apps/third_party/LGM/data/anya_rgba.png and /dev/null differ diff --git a/apps/third_party/LGM/data/corgi.jpg b/apps/third_party/LGM/data/corgi.jpg deleted file mode 100644 index 946cd78b1234ed2abd7d63bbb320a5939b349ca0..0000000000000000000000000000000000000000 Binary files a/apps/third_party/LGM/data/corgi.jpg and /dev/null differ diff --git a/apps/third_party/LGM/mv_unet.py b/apps/third_party/LGM/mv_unet.py deleted file mode 100644 index 6a4c58faff5736531f36c433c335fcd77790edc1..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/mv_unet.py +++ /dev/null @@ -1,1005 +0,0 @@ -import math -import numpy as np -from inspect import isfunction -from typing import Optional, Any, List - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat - -from diffusers.configuration_utils import ConfigMixin -from diffusers.models.modeling_utils import ModelMixin - -# require xformers! -import xformers -import xformers.ops - -from kiui.cam import orbit_camera - -def get_camera( - num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False, -): - angle_gap = azimuth_span / num_frames - cameras = [] - for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): - - pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4] - - # opengl to blender - if blender_coord: - pose[2] *= -1 - pose[[1, 2]] = pose[[2, 1]] - - cameras.append(pose.flatten()) - - if extra_view: - cameras.append(np.zeros_like(cameras[0])) - - return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16] - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None] * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - # import pdb; pdb.set_trace() - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def default(val, d): - if val is not None: - return val - return d() if isfunction(d) else d - - -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__( - self, - query_dim, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - ip_dim=0, - ip_weight=1, - ): - super().__init__() - - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - - self.ip_dim = ip_dim - self.ip_weight = ip_weight - - if self.ip_dim > 0: - self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - self.attention_op = None - - def forward(self, x, context=None): - q = self.to_q(x) - context = default(context, x) - - if self.ip_dim > 0: - # context: [B, 77 + 16(ip), 1024] - token_len = context.shape[1] - context_ip = context[:, -self.ip_dim :, :] - k_ip = self.to_k_ip(context_ip) - v_ip = self.to_v_ip(context_ip) - context = context[:, : (token_len - self.ip_dim), :] - - k = self.to_k(context) - v = self.to_v(context) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op - ) - - if self.ip_dim > 0: - k_ip, v_ip = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (k_ip, v_ip), - ) - # actually compute the attention, what we cannot get enough of - out_ip = xformers.ops.memory_efficient_attention( - q, k_ip, v_ip, attn_bias=None, op=self.attention_op - ) - out = out + self.ip_weight * out_ip - - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - return self.to_out(out) - - -class BasicTransformerBlock3D(nn.Module): - - def __init__( - self, - dim, - n_heads, - d_head, - context_dim, - dropout=0.0, - gated_ff=True, - ip_dim=0, - ip_weight=1, - ): - super().__init__() - - self.attn1 = MemoryEfficientCrossAttention( - query_dim=dim, - context_dim=None, # self-attention - heads=n_heads, - dim_head=d_head, - dropout=dropout, - ) - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = MemoryEfficientCrossAttention( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - # ip only applies to cross-attention - ip_dim=ip_dim, - ip_weight=ip_weight, - ) - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - - def forward(self, x, context=None, num_frames=1): - x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() - x = self.attn1(self.norm1(x), context=None) + x - x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer3D(nn.Module): - - def __init__( - self, - in_channels, - n_heads, - d_head, - context_dim, # cross attention input dim - depth=1, - dropout=0.0, - ip_dim=0, - ip_weight=1, - ): - super().__init__() - - if not isinstance(context_dim, list): - context_dim = [context_dim] - - self.in_channels = in_channels - - inner_dim = n_heads * d_head - self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock3D( - inner_dim, - n_heads, - d_head, - context_dim=context_dim[d], - dropout=dropout, - ip_dim=ip_dim, - ip_weight=ip_weight, - ) - for d in range(depth) - ] - ) - - self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - - - def forward(self, x, context=None, num_frames=1): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = rearrange(x, "b c h w -> b (h w) c").contiguous() - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i], num_frames=num_frames) - x = self.proj_out(x) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() - - return x + x_in - - -class PerceiverAttention(nn.Module): - def __init__(self, *, dim, dim_head=64, heads=8): - super().__init__() - self.scale = dim_head ** -0.5 - self.dim_head = dim_head - self.heads = heads - inner_dim = dim_head * heads - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - def forward(self, x, latents): - """ - Args: - x (torch.Tensor): image features - shape (b, n1, D) - latent (torch.Tensor): latent features - shape (b, n2, D) - """ - x = self.norm1(x) - latents = self.norm2(latents) - - b, l, _ = latents.shape - - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - - q, k, v = map( - lambda t: t.reshape(b, t.shape[1], self.heads, -1) - .transpose(1, 2) - .reshape(b, self.heads, t.shape[1], -1) - .contiguous(), - (q, k, v), - ) - - # attention - scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v - - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) - - return self.to_out(out) - - -class Resampler(nn.Module): - def __init__( - self, - dim=1024, - depth=8, - dim_head=64, - heads=16, - num_queries=8, - embedding_dim=768, - output_dim=1024, - ff_mult=4, - ): - super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) - self.proj_in = nn.Linear(embedding_dim, dim) - self.proj_out = nn.Linear(dim, output_dim) - self.norm_out = nn.LayerNorm(output_dim) - - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), - nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, dim * ff_mult, bias=False), - nn.GELU(), - nn.Linear(dim * ff_mult, dim, bias=False), - ) - ] - ) - ) - - def forward(self, x): - latents = self.latents.repeat(x.size(0), 1, 1) - x = self.proj_in(x) - for attn, ff in self.layers: - latents = attn(x, latents) + latents - latents = ff(latents) + latents - - latents = self.proj_out(latents) - return self.norm_out(latents) - - -class CondSequential(nn.Sequential): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None, num_frames=1): - for layer in self: - if isinstance(layer, ResBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer3D): - x = layer(x, context, num_frames=num_frames) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd( - dims, self.channels, self.out_channels, 3, padding=padding - ) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, - self.channels, - self.out_channels, - 3, - stride=stride, - padding=padding, - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(nn.Module): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - nn.GroupNorm(32, channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - nn.Linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - nn.GroupNorm(32, self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = torch.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class MultiViewUNetModel(ModelMixin, ConfigMixin): - """ - The full multi-view UNet model with attention, timestep embedding and camera embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - :param camera_dim: dimensionality of camera input. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - transformer_depth=1, - context_dim=None, - n_embed=None, - num_attention_blocks=None, - adm_in_channels=None, - camera_dim=None, - ip_dim=0, # imagedream uses ip_dim > 0 - ip_weight=1.0, - **kwargs, - ): - super().__init__() - assert context_dim is not None - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert ( - num_head_channels != -1 - ), "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert ( - num_heads != -1 - ), "Either num_heads or num_head_channels has to be set" - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - if isinstance(num_res_blocks, int): - self.num_res_blocks = len(channel_mult) * [num_res_blocks] - else: - if len(num_res_blocks) != len(channel_mult): - raise ValueError( - "provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult" - ) - self.num_res_blocks = num_res_blocks - - if num_attention_blocks is not None: - assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all( - map( - lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], - range(len(num_attention_blocks)), - ) - ) - print( - f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set." - ) - - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - - self.ip_dim = ip_dim - self.ip_weight = ip_weight - - if self.ip_dim > 0: - self.image_embed = Resampler( - dim=context_dim, - depth=4, - dim_head=64, - heads=12, - num_queries=ip_dim, # num token - embedding_dim=1280, - output_dim=context_dim, - ff_mult=4, - ) - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - nn.Linear(model_channels, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) - - if camera_dim is not None: - time_embed_dim = model_channels * 4 - self.camera_embed = nn.Sequential( - nn.Linear(camera_dim, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) - elif self.num_classes == "continuous": - # print("setting up linear c_adm embedding layer") - self.label_emb = nn.Linear(1, time_embed_dim) - elif self.num_classes == "sequential": - assert adm_in_channels is not None - self.label_emb = nn.Sequential( - nn.Sequential( - nn.Linear(adm_in_channels, time_embed_dim), - nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), - ) - ) - else: - raise ValueError() - - self.input_blocks = nn.ModuleList( - [ - CondSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for nr in range(self.num_res_blocks[level]): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - if num_attention_blocks is None or nr < num_attention_blocks[level]: - layers.append( - SpatialTransformer3D( - ch, - num_heads, - dim_head, - context_dim=context_dim, - depth=transformer_depth, - ip_dim=self.ip_dim, - ip_weight=self.ip_weight, - ) - ) - self.input_blocks.append(CondSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - CondSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - self.middle_block = CondSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - ), - SpatialTransformer3D( - ch, - num_heads, - dim_head, - context_dim=context_dim, - depth=transformer_depth, - ip_dim=self.ip_dim, - ip_weight=self.ip_weight, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(self.num_res_blocks[level] + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - if num_attention_blocks is None or i < num_attention_blocks[level]: - layers.append( - SpatialTransformer3D( - ch, - num_heads, - dim_head, - context_dim=context_dim, - depth=transformer_depth, - ip_dim=self.ip_dim, - ip_weight=self.ip_weight, - ) - ) - if level and i == self.num_res_blocks[level]: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(CondSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - nn.GroupNorm(32, ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - if self.predict_codebook_ids: - self.id_predictor = nn.Sequential( - nn.GroupNorm(32, ch), - conv_nd(dims, model_channels, n_embed, 1), - # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) - - def forward( - self, - x, - timesteps=None, - context=None, - y=None, - camera=None, - num_frames=1, - ip=None, - ip_img=None, - **kwargs, - ): - """ - Apply the model to an input batch. - :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :param num_frames: a integer indicating number of frames for tensor reshaping. - :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). - """ - assert ( - x.shape[0] % num_frames == 0 - ), "input batch size must be dividable by num_frames!" - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - - hs = [] - - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) - - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y is not None - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - # Add camera embeddings - if camera is not None: - emb = emb + self.camera_embed(camera) - - # imagedream variant - if self.ip_dim > 0: - x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9] - ip_emb = self.image_embed(ip) - context = torch.cat((context, ip_emb), 1) - - h = x - for module in self.input_blocks: - h = module(h, emb, context, num_frames=num_frames) - hs.append(h) - h = self.middle_block(h, emb, context, num_frames=num_frames) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb, context, num_frames=num_frames) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) \ No newline at end of file diff --git a/apps/third_party/LGM/pipeline_mvdream.py b/apps/third_party/LGM/pipeline_mvdream.py deleted file mode 100644 index cf66bd2683583c07418bf479df70ce14ff4d9185..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/pipeline_mvdream.py +++ /dev/null @@ -1,557 +0,0 @@ -import torch -import torch.nn.functional as F -import inspect -import numpy as np -from typing import Callable, List, Optional, Union -from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor -from diffusers import AutoencoderKL, DiffusionPipeline -from diffusers.utils import ( - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, -) -from diffusers.configuration_utils import FrozenDict -from diffusers.schedulers import DDIMScheduler -from diffusers.utils.torch_utils import randn_tensor - -from apps.third_party.LGM.mv_unet import MultiViewUNetModel, get_camera - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class MVDreamPipeline(DiffusionPipeline): - - _optional_components = ["feature_extractor", "image_encoder"] - - def __init__( - self, - vae: AutoencoderKL, - unet: MultiViewUNetModel, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - scheduler: DDIMScheduler, - # imagedream variant - feature_extractor: CLIPImageProcessor, - image_encoder: CLIPVisionModel, - requires_safety_checker: bool = False, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate( - "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False - ) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - unet=unet, - scheduler=scheduler, - tokenizer=tokenizer, - text_encoder=text_encoder, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. - - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError( - "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher" - ) - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError( - "`enable_model_offload` requires `accelerate v0.17.0` or higher." - ) - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook( - cpu_offloaded_model, device, prev_module_hook=hook - ) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance: bool, - negative_prompt=None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError( - f"`prompt` should be either a string or a list of strings, but got {type(prompt)}." - ) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.text_encoder.dtype, device=device - ) - - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - shape = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def encode_image(self, image, device, num_images_per_prompt): - dtype = next(self.image_encoder.parameters()).dtype - - if image.dtype == np.float32: - image = (image * 255).astype(np.uint8) - - image = self.feature_extractor(image, return_tensors="pt").pixel_values - image = image.to(device=device, dtype=dtype) - - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - return torch.zeros_like(image_embeds), image_embeds - - def encode_image_latents(self, image, device, num_images_per_prompt): - - dtype = next(self.image_encoder.parameters()).dtype - - image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W] - image = 2 * image - 1 - image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) - image = image.to(dtype=dtype) - - posterior = self.vae.encode(image).latent_dist - latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W] - latents = latents.repeat_interleave(num_images_per_prompt, dim=0) - - return torch.zeros_like(latents), latents - - @torch.no_grad() - def __call__( - self, - prompt: str = "", - image: Optional[np.ndarray] = None, - height: int = 256, - width: int = 256, - elevation: float = 0, - num_inference_steps: int = 50, - guidance_scale: float = 7.0, - negative_prompt: str = "", - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - output_type: Optional[str] = "numpy", # pil, numpy, latents - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - num_frames: int = 4, - device=torch.device("cuda:0"), - ): - self.unet = self.unet.to(device=device) - self.vae = self.vae.to(device=device) - self.text_encoder = self.text_encoder.to(device=device) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # imagedream variant - if image is not None: - assert isinstance(image, np.ndarray) and image.dtype == np.float32 - self.image_encoder = self.image_encoder.to(device=device) - image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt) - image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt) - - _prompt_embeds = self._encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - ) # type: ignore - prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2) - - # Prepare latent variables - actual_num_frames = num_frames if image is None else num_frames + 1 - latents: torch.Tensor = self.prepare_latents( - actual_num_frames * num_images_per_prompt, - 4, - height, - width, - prompt_embeds_pos.dtype, - device, - generator, - None, - ) - - # Get camera - camera = get_camera(num_frames, elevation=elevation, extra_view=(image is not None)).to(dtype=latents.dtype, device=device) - camera = camera.repeat_interleave(num_images_per_prompt, dim=0) - - # Prepare extra step kwargs. - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - multiplier = 2 if do_classifier_free_guidance else 1 - latent_model_input = torch.cat([latents] * multiplier) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - unet_inputs = { - 'x': latent_model_input, - 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device), - 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames), - 'num_frames': actual_num_frames, - 'camera': torch.cat([camera] * multiplier), - } - - if image is not None: - unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames) - unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat - - # predict the noise residual - noise_pred = self.unet.forward(**unet_inputs) - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents: torch.Tensor = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) # type: ignore - - # Post-processing - if output_type == "latent": - image = latents - elif output_type == "pil": - image = self.decode_latents(latents) - image = self.numpy_to_pil(image) - else: # numpy - image = self.decode_latents(latents) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - return image diff --git a/apps/third_party/LGM/requirements.lock.txt b/apps/third_party/LGM/requirements.lock.txt deleted file mode 100644 index 5c6c4b55fefc0710ff4971eed9b0122da8c985da..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/requirements.lock.txt +++ /dev/null @@ -1,7 +0,0 @@ -omegaconf == 2.3.0 -diffusers == 0.23.1 -safetensors == 0.4.1 -huggingface_hub == 0.19.4 -transformers == 4.35.2 -accelerate == 0.25.0.dev0 -kiui == 0.2.0 \ No newline at end of file diff --git a/apps/third_party/LGM/requirements.txt b/apps/third_party/LGM/requirements.txt deleted file mode 100644 index ed8bdbc6697ac3bceef7cd610987cd0ba5777147..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -omegaconf -diffusers -safetensors -huggingface_hub -transformers -accelerate -kiui -einops -rich diff --git a/apps/third_party/LGM/run_imagedream.py b/apps/third_party/LGM/run_imagedream.py deleted file mode 100644 index 258a45afb045e4b7764f6903c177d757f5f00822..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/run_imagedream.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import kiui -import numpy as np -import argparse -from pipeline_mvdream import MVDreamPipeline -import ipdb -pipe = MVDreamPipeline.from_pretrained( - # "./weights_imagedream", # local weights - "/mnt/cfs/home/liweiyu/codes/3DNativeGeneration/ckpts/pretrained_weights/huggingface/hub/models--ashawkey--imagedream-ipmv-diffusers/snapshots/73a034178e748421506492e91790cc62d6aefef5", # remote weights - torch_dtype=torch.float16, - trust_remote_code=True, -) -pipe = pipe.to("cuda") - - -parser = argparse.ArgumentParser(description="ImageDream") -parser.add_argument("image", type=str, default='data/anya_rgba.png') -parser.add_argument("--prompt", type=str, default="") -args = parser.parse_args() - -for i in range(5): - input_image = kiui.read_image(args.image, mode='float') - image = pipe(args.prompt, input_image, guidance_scale=5, num_inference_steps=30, elevation=0) - ipdb.set_trace() - # print(image) - grid = np.concatenate( - [ - np.concatenate([image[0], image[2]], axis=0), - np.concatenate([image[1], image[3]], axis=0), - ], - axis=1, - ) - # kiui.vis.plot_image(grid) - kiui.write_image(f'test_imagedream_{i}.jpg', grid) diff --git a/apps/third_party/LGM/run_mvdream.py b/apps/third_party/LGM/run_mvdream.py deleted file mode 100644 index c80b48e1f30a2ab85448f4c6a89f2cd47737404b..0000000000000000000000000000000000000000 --- a/apps/third_party/LGM/run_mvdream.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import kiui -import numpy as np -import argparse -from pipeline_mvdream import MVDreamPipeline - -import ipdb -pipe = MVDreamPipeline.from_pretrained( - # "./weights_mvdream", # local weights - '/mnt/cfs/home/liweiyu/codes/3DNativeGeneration/ckpts/pretrained_weights/huggingface/hub/models--ashawkey--mvdream-sd2.1-diffusers/snapshots/503bb19fc2b2bc542c2afdb7d73ac87a7cbc2253', # remote weights - torch_dtype=torch.float16, - # trust_remote_code=True, -) - -pipe = pipe.to("cuda") - - -parser = argparse.ArgumentParser(description="MVDream") -parser.add_argument("prompt", type=str, default="a cute owl 3d model") -args = parser.parse_args() - -for i in range(5): - image = pipe(args.prompt, guidance_scale=5, num_inference_steps=30, elevation=0) - ipdb.set_trace() - grid = np.concatenate( - [ - np.concatenate([image[0], image[2]], axis=0), - np.concatenate([image[1], image[3]], axis=0), - ], - axis=1, - ) - # kiui.vis.plot_image(grid) - kiui.write_image(f'test_mvdream_{i}.jpg', grid) diff --git a/apps/utils.py b/apps/utils.py deleted file mode 100644 index 9b4a2e90001d1377599d699a0abca40ca08668c9..0000000000000000000000000000000000000000 --- a/apps/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -from typing import Dict, Optional, Tuple, List -from dataclasses import dataclass -import os -import sys -proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(os.path.join(proj_dir)) -import time -import cv2 -import gradio as gr -import numpy as np -import torch -import PIL -from PIL import Image -import rembg -from rembg import remove -rembg_session = rembg.new_session() -from segment_anything import sam_model_registry, SamPredictor - -import craftsman -from craftsman.systems.base import BaseSystem -from craftsman.utils.config import ExperimentConfig, load_config - -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - -def check_input_image(input_image): - if input_image is None: - raise gr.Error("No image uploaded!") - -def load_model( - ckpt_path: str, - config_path: str, - device = "cuda" - ): - cfg: ExperimentConfig - cfg = load_config(config_path) - - if 'pretrained_model_name_or_path' not in cfg.system.condition_model or cfg.system.condition_model.pretrained_model_name_or_path is None: - cfg.system.condition_model.config_path = config_path.replace("config.yaml", "clip_config.json") - - system: BaseSystem = craftsman.find(cfg.system_type)( - cfg.system, - ) - - print(f"Restoring states from the checkpoint path at {ckpt_path} with config {cfg}") - system.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']) - system = system.to(device).eval() - - return system - -class RMBG(object): - def __init__(self, device): - sam = sam_model_registry["vit_h"](checkpoint=f"{parent_dir}/ckpts/SAM/sam_vit_h_4b8939.pth").to(device) - self.predictor = SamPredictor(sam) - - def rmbg_sam(self, input_image): - def _sam_segment(predictor, input_image, *bbox_coords): - bbox = np.array(bbox_coords) - image = np.asarray(input_image) - - start_time = time.time() - predictor.set_image(image) - - masks_bbox, scores_bbox, logits_bbox = predictor.predict(box=bbox, multimask_output=True) - - print(f"SAM Time: {time.time() - start_time:.3f}s") - out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) - out_image[:, :, :3] = image - out_image_bbox = out_image.copy() - out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 - torch.cuda.empty_cache() - return Image.fromarray(out_image_bbox, mode='RGBA') - - RES = 1024 - input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS) - - image_rem = input_image.convert('RGBA') - image_nobg = remove(image_rem, alpha_matting=True) - arr = np.asarray(image_nobg)[:, :, -1] - x_nonzero = np.nonzero(arr.sum(axis=0)) - y_nonzero = np.nonzero(arr.sum(axis=1)) - x_min = int(x_nonzero[0].min()) - y_min = int(y_nonzero[0].min()) - x_max = int(x_nonzero[0].max()) - y_max = int(y_nonzero[0].max()) - return _sam_segment(self.predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max) - - def rmbg_rembg(self, input_image): - def _rembg_remove( - image: PIL.Image.Image, - rembg_session = None, - force: bool = False, - **rembg_kwargs, - ) -> PIL.Image.Image: - do_remove = True - if image.mode == "RGBA" and image.getextrema()[3][0] < 255: - # explain why current do not rm bg - print("alhpa channl not enpty, skip remove background, using alpha channel as mask") - background = Image.new("RGBA", image.size, (0, 0, 0, 0)) - image = Image.alpha_composite(background, image) - do_remove = False - do_remove = do_remove or force - if do_remove: - image = rembg.remove(image, session=rembg_session, **rembg_kwargs) - return image - return _rembg_remove(input_image, rembg_session, force_remove=True) - - def run(self, rm_type, image, foreground_ratio, background_choice, backgroud_color): - # image = cv2.resize(np.array(image), (crop_size, crop_size)) - # image = Image.fromarray(image) - - if background_choice == "Alpha as mask": - background = Image.new("RGBA", image.size, (backgroud_color[0], backgroud_color[1], backgroud_color[2], 0)) - return Image.alpha_composite(background, image) - elif "Remove" in background_choice: - if rm_type.upper() == "SAM": - image = self.rmbg_sam(image) - elif rm_type.upper() == "REMBG": - image = self.rmbg_rembg(image) - else: - return -1 - - image = do_resize_content(image, foreground_ratio) - image = expand_to_square(image) - # image = add_background(image, backgroud_color) - # return image.convert("RGB") - return image - - elif "Original" in background_choice: - return image - else: - return -1 - -def do_resize_content(original_image: Image, scale_rate): - # resize image content wile retain the original image size - if scale_rate != 1: - # Calculate the new size after rescaling - new_size = tuple(int(dim * scale_rate) for dim in original_image.size) - # Resize the image while maintaining the aspect ratio - resized_image = original_image.resize(new_size) - # Create a new image with the original size and black background - padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) - paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) - padded_image.paste(resized_image, paste_position) - return padded_image - else: - return original_image - -def expand2square(pil_img, background_color): - width, height = pil_img.size - if width == height: - return pil_img - elif width > height: - result = Image.new(pil_img.mode, (width, width), background_color) - result.paste(pil_img, (0, (width - height) // 2)) - return result - else: - result = Image.new(pil_img.mode, (height, height), background_color) - result.paste(pil_img, ((height - width) // 2, 0)) - return result - -def expand_to_square(image, bg_color=(0, 0, 0, 0)): - # expand image to 1:1 - width, height = image.size - if width == height: - return image - new_size = (max(width, height), max(width, height)) - new_image = Image.new("RGBA", new_size, bg_color) - paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) - new_image.paste(image, paste_position) - return new_image - -def add_background(image, bg_color=(255, 255, 255)): - # given an RGBA image, alpha channel is used as mask to add background color - background = Image.new("RGBA", image.size, bg_color) - return Image.alpha_composite(background, image) \ No newline at end of file diff --git a/asset/demo_result.png b/asset/demo_result.png new file mode 100644 index 0000000000000000000000000000000000000000..b00881e42677fdd54df3a22624e7ea5233c58163 Binary files /dev/null and b/asset/demo_result.png differ diff --git a/ckpts/CRM/pixel-diffusion.pth b/ckpts/CRM/pixel-diffusion.pth deleted file mode 100644 index d5540db4a739a1cb7f601735cdaa2ab44751cc0e..0000000000000000000000000000000000000000 --- a/ckpts/CRM/pixel-diffusion.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:27445c0202f51cb7e6715978e2d7811d1f9a276adae0bc7af8bf7191dd2e71e0 -size 6162765693 diff --git a/ckpts/SAM/sam_vit_h_4b8939.pth b/ckpts/SAM/sam_vit_h_4b8939.pth deleted file mode 100755 index 8523acce9ddab1cf7e355628a08b1aab8ce08a72..0000000000000000000000000000000000000000 --- a/ckpts/SAM/sam_vit_h_4b8939.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e -size 2564550879 diff --git a/configs/image-to-shape-diffusion/clip-dino-rgb-pixart-lr2e4-ddim.yaml b/configs/image-to-shape-diffusion/clip-dino-rgb-pixart-lr2e4-ddim.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4e2aa03d6454c8ce4d36b56541019b33ec564393 --- /dev/null +++ b/configs/image-to-shape-diffusion/clip-dino-rgb-pixart-lr2e4-ddim.yaml @@ -0,0 +1,149 @@ +exp_root_dir: "outputs" +name: "image-to-shape-diffusion/clip-dino-rgb-pixart-lr2e4-ddim" +tag: "${rmspace:${system.shape_model_type}+n${data.n_samples}+pfeat${system.shape_model.point_feats}+lr${system.optimizer.args.lr},_}" +seed: 0 + +data_type: "objaverse-datamodule" +data: + root_dir: ./data/objaverse + data_type: "sdf" + sampling_strategy: random + n_samples: 10240 + + load_supervision: False + supervision_type: "" + n_supervision: 0 + + load_image: True # whether to load images + image_data_path: ./data/objaverse/render+blender+singleview+nv20 + image_type: "rgb" # rgb, normal + idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] # front view + n_views: 1 + background_color: [0.5, 0.5, 0.5] + marign_pix_dis: 30 + + batch_size: 40 + num_workers: 16 + +system_type: "pixart-diffusion-system" +system: + val_samples_json: "val_data/images/val_samples_rgb_image.json" + z_scale_factor: 1.0 + guidance_scale: 7.5 + num_inference_steps: 50 + eta: 0.0 + extract_mesh_func: diffdmc + + shape_model_type: michelangelo-autoencoder + shape_model: + pretrained_model_name_or_path: /mnt/cfs/public/native3D/ckpts/michelangelo-autoencoder-l256-e64-ne8-nd16-scaleup.ckpt + use_downsample: true + downsample_ratio: 0.0625 + num_latents: 768 + use_multi_reso: false + resolutions: [4096, 8192, 12288] + sampling_prob: [0, 0, 1] + embed_dim: 64 + point_feats: 3 + out_dim: 1 + num_freqs: 8 + include_pi: false + heads: 12 + width: 768 + num_encoder_layers: 8 + num_decoder_layers: 16 + use_ln_post: true + init_scale: 0.25 + qkv_bias: false + use_flash: true + use_checkpoint: true + + + condition_model_type: "cond-embedder" + condition_model: + pretrained_clip_name_or_path: openai/clip-vit-large-patch14 + pretrained_dino_name_or_path: facebook/dinov2-base + pretrained_tokenizer_name_or_path: openai/clip-vit-large-patch14 + freeze_modulation_clip: true + freeze_modulation_dino: true + encode_camera: false + camera_embeds_dim: 0 + n_views: ${data.n_views} + empty_embeds_ratio: 0.1 + normalize_embeds: false + zero_uncond_embeds: true + linear_proj_init: constant + image_size_dino: 224 + image_size_clip: 224 + + denoiser_model_type: "pixart-denoiser" + denoiser_model: + input_channels: ${system.shape_model.embed_dim} + output_channels: ${system.shape_model.embed_dim} + n_ctx: ${system.shape_model.num_latents} + width: 768 + layers: 32 + heads: 12 + context_dim: 1024 + init_scale: 1.0 + skip_ln: true + variance_type: ${system.noise_scheduler.variance_type} + use_checkpoint: true + dit_block: DiTBlock + + noise_scheduler_type: "diffusers.schedulers.DDPMScheduler" + noise_scheduler: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + variance_type: "fixed_small" + clip_sample: false + + denoise_scheduler_type: "diffusers.schedulers.DDIMScheduler" + denoise_scheduler: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + clip_sample: false # clip sample to -1~1 + set_alpha_to_one: false + steps_offset: 1 + + loggers: + wandb: + enable: false + project: "CraftsMan" + name: image-to-shape-diffusion+${name}+${tag} + + loss: + loss_type: "mse" + lambda_diffusion: 1. + + optimizer: + name: AdamW + args: + lr: 2.e-4 + betas: [0.9, 0.99] + eps: 1.e-6 + + scheduler: + name: CosineAnnealingLR + args: + T_max: 5000 + eta_min: 1e-6 + +trainer: + num_nodes: 1 + max_epochs: 100000 + log_every_n_steps: 5 + num_sanity_val_steps: 1 + check_val_every_n_epoch: 25 + enable_progress_bar: true + precision: 16-mixed + strategy: 'ddp_find_unused_parameters_true' + +checkpoint: + save_last: true + save_top_k: -1 + every_n_train_steps: 5000 \ No newline at end of file diff --git a/configs/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6.yaml b/configs/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6.yaml deleted file mode 100755 index e7b649eeb566ed3c0fc59e9b1c423ffd7ea2c3ab..0000000000000000000000000000000000000000 --- a/configs/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6.yaml +++ /dev/null @@ -1,148 +0,0 @@ -exp_root_dir: "outputs" -name: "image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6" -tag: "${rmspace:${system.shape_model_type}+n${data.n_samples}+noise${data.noise_sigma}+pfeat${system.shape_model.point_feats}+normemb${system.condition_model.normalize_embeds}+lr${system.optimizer.args.lr}+qkvbias${system.shape_model.qkv_bias}+nfreq${system.shape_model.num_freqs}+ln_post${system.shape_model.use_ln_post},_}" -seed: 0 - -data_type: "objaverse-datamodule" -data: - root_dir: "data/objaverse_clean/cap3d_high_quality_170k_images" - data_type: "occupancy" - n_samples: 4096 - noise_sigma: 0. - - load_supervision: False - supervision_type: "occupancy" - n_supervision: 4096 - - load_image: True # whether to load images - image_data_path: data/objaverse_clean/raw_data/images/cap3d_high_quality_170k - image_type: "mvrgb" # rgb, normal, mvrgb, mvnormal - idx: [0, 4, 8, 12, 16] - n_views: 4 - load_caption: False # whether to load captions - rotate_points: False - - batch_size: 32 - num_workers: 16 - -system_type: "shape-diffusion-system" -system: - val_samples_json: "val_data/mv_images/val_samples_rgb_mvimage.json" - z_scale_factor: 1.0 - guidance_scale: 7.5 - num_inference_steps: 50 - eta: 0.0 - - shape_model_type: "michelangelo-autoencoder" - shape_model: - # pretrained_model_name_or_path: ./ckpts/3DNativeGeneration/michelangelo-aligned-autoencoder-l256-e64-ne8-nd16.ckpt - pretrained_model_name_or_path: "./outputs/image-to-shape-diffusion_bak/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/michelangelo-autoencoder+n4096+noise0.0+pfeat3+normembFalse+lr5e-05+qkvbiasFalse+nfreq8+ln_postTrue/ckpts/last.ckpt" - num_latents: 256 - embed_dim: 64 - point_feats: 3 # xyz + normal - out_dim: 1 # only occupancy - num_freqs: 8 - include_pi: false - heads: 12 - width: 768 - num_encoder_layers: 8 - num_decoder_layers: 16 - use_ln_post: true - init_scale: 0.25 - qkv_bias: false - use_flash: true - use_checkpoint: true - - condition_model_type: "clip-embedder" - condition_model: - pretrained_model_name_or_path: "./ckpts/pretrained_weights/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff" - encode_camera: true - camera_embeds_dim: 32 # 16 * 2[sin, cos] - n_views: ${data.n_views} - empty_embeds_ratio: 0.1 - normalize_embeds: false - # zero_uncond_embeds: true - zero_uncond_embeds: false - - denoiser_model_type: "simple-denoiser" - denoiser_model: - # pretrained_model_name_or_path: "./ckpts/CraftsMan/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6.pth" - pretrained_model_name_or_path: "./ckpts/CraftsMan/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-It500000.pth" - input_channels: ${system.shape_model.embed_dim} - output_channels: ${system.shape_model.embed_dim} - n_ctx: ${system.shape_model.num_latents} - width: 768 - layers: 6 # 2 * 6 + 1 = 13 - heads: 12 - context_dim: 1024 - init_scale: 1.0 - skip_ln: true - use_checkpoint: true - - noise_scheduler_type: "diffusers.schedulers.DDPMScheduler" - noise_scheduler: - num_train_timesteps: 1000 - beta_start: 0.00085 - beta_end: 0.012 - beta_schedule: "scaled_linear" - variance_type: "fixed_small" - clip_sample: false - - denoise_scheduler_type: "diffusers.schedulers.DDIMScheduler" - denoise_scheduler: - num_train_timesteps: 1000 - beta_start: 0.00085 - beta_end: 0.012 - beta_schedule: "scaled_linear" - clip_sample: false # clip sample to -1~1 - set_alpha_to_one: false - steps_offset: 1 - - loggers: - wandb: - enable: false - project: "CraftsMan" - name: image-to-shape-diffusion+${name}+${tag} - - loss: - loss_type: "mse" - lambda_diffusion: 1. - - optimizer: - name: AdamW - args: - lr: 5.e-5 - betas: [0.9, 0.99] - eps: 1.e-6 - - scheduler: - name: SequentialLR - interval: step - schedulers: - - name: LinearLR - interval: step - args: - start_factor: 1e-6 - end_factor: 1.0 - total_iters: 5000 - - name: CosineAnnealingLR - interval: step - args: - T_max: 5000 - eta_min: 0. - milestones: [5000] - -trainer: - num_nodes: 1 - max_epochs: 100000 - log_every_n_steps: 5 - num_sanity_val_steps: 1 - check_val_every_n_epoch: 3 - enable_progress_bar: true - precision: 16-mixed - strategy: 'ddp_find_unused_parameters_true' - -checkpoint: - save_last: true - save_top_k: -1 - every_n_train_steps: 5000 \ No newline at end of file diff --git a/configs/shape-autoencoder/l256-e64-ne8-nd16.yaml b/configs/shape-autoencoder/l256-e64-ne8-nd16.yaml deleted file mode 100755 index cf779a3a7ad958e51ed6ea898ede9a497338f676..0000000000000000000000000000000000000000 --- a/configs/shape-autoencoder/l256-e64-ne8-nd16.yaml +++ /dev/null @@ -1,95 +0,0 @@ -exp_root_dir: "outputs" -name: "michelangelo-autoencoder/l256-e64-ne8-nd16" -tag: "${rmspace:n${data.n_samples}+${data.supervision_type}+rot${data.rotate}+noise${data.noise_sigma}+${system.shape_model.embed_type}+dsample${system.shape_model.use_downsample}+pfeat${system.shape_model.point_feats}+logits${system.loss.lambda_logits}+kl${system.loss.lambda_kl}+lr${system.optimizer.args.lr},_}" -seed: 0 - -data_type: "objaverse-datamodule" -data: - root_dir: "data/objaverse_clean/sdf_100k" - data_type: "sdf" - n_samples: 4096 - noise_sigma: 0. - rotate: False - - load_supervision: True - supervision_type: "occupancy" - n_supervision: 4096 - - load_image: False # whether to load images - load_caption: False # whether to load captions - - batch_size: 128 - num_workers: 16 - -system_type: "shape-autoencoder-system" -system: - sample_posterior: true - - shape_model_type: "michelangelo-autoencoder" - shape_model: - num_latents: 256 # 256 - embed_dim: 64 - point_feats: 3 # xyz + normal - out_dim: 1 # only occupancy - embed_type: "fourier" - num_freqs: 8 - include_pi: false - heads: 12 - width: 768 - num_encoder_layers: 8 - num_decoder_layers: 16 - use_ln_post: true - init_scale: 0.25 - qkv_bias: true - use_flash: true - use_checkpoint: true - use_downsample: true - - loggers: - wandb: - enable: false - project: "CraftsMan" - name: shape-autoencoder+${name}+${tag} - - loss: - lambda_logits: 1. - lambda_kl: 0.001 - - optimizer: - name: AdamW - args: - lr: 1.e-4 - betas: [0.9, 0.99] - eps: 1.e-6 - - scheduler: - name: SequentialLR - interval: step - schedulers: - - name: LinearLR - interval: step - args: - start_factor: 1e-6 - end_factor: 1.0 - total_iters: 5000 - - name: CosineAnnealingLR - interval: step - args: - T_max: 5000 - eta_min: 0. - milestones: [5000] - -trainer: - num_nodes: 1 - max_epochs: 100000 - log_every_n_steps: 5 - num_sanity_val_steps: 1 - # val_check_interval: 200 - check_val_every_n_epoch: 10 - enable_progress_bar: true - precision: 16-mixed - -checkpoint: - save_last: true - save_top_k: -1 - every_n_train_steps: 5000 \ No newline at end of file diff --git a/configs/shape-autoencoder/l512-e64-ne8-nd16.yaml b/configs/shape-autoencoder/l512-e64-ne8-nd16.yaml deleted file mode 100755 index e32ef5272757c4f4c8d0f3067a1c029818543475..0000000000000000000000000000000000000000 --- a/configs/shape-autoencoder/l512-e64-ne8-nd16.yaml +++ /dev/null @@ -1,95 +0,0 @@ -exp_root_dir: "outputs" -name: "michelangelo-autoencoder/l512-e64-ne8-nd16" -tag: "${rmspace:n${data.n_samples}+${data.supervision_type}+rot${data.rotate}+noise${data.noise_sigma}+${system.shape_model.embed_type}+dsample${system.shape_model.use_downsample}+pfeat${system.shape_model.point_feats}+logits${system.loss.lambda_logits}+kl${system.loss.lambda_kl}+lr${system.optimizer.args.lr},_}" -seed: 0 - -data_type: "objaverse-datamodule" -data: - root_dir: "data/objaverse_clean/sdf_100k" - data_type: "sdf" - n_samples: 4096 - noise_sigma: 0. - rotate: False - - load_supervision: True - supervision_type: "occupancy" - n_supervision: 4096 - - load_image: False # whether to load images - load_caption: False # whether to load captions - - batch_size: 128 - num_workers: 16 - -system_type: "shape-autoencoder-system" -system: - sample_posterior: true - - shape_model_type: "michelangelo-autoencoder" - shape_model: - num_latents: 512 # 512 - embed_dim: 64 - point_feats: 3 # xyz + normal - out_dim: 1 # only occupancy - embed_type: "fourier" - num_freqs: 8 - include_pi: false - heads: 12 - width: 768 - num_encoder_layers: 8 - num_decoder_layers: 16 - use_ln_post: true - init_scale: 0.25 - qkv_bias: true - use_flash: true - use_checkpoint: true - use_downsample: true - - loggers: - wandb: - enable: false - project: "CraftsMan" - name: shape-autoencoder+${name}+${tag} - - loss: - lambda_logits: 1. - lambda_kl: 0.001 - - optimizer: - name: AdamW - args: - lr: 1.e-4 - betas: [0.9, 0.99] - eps: 1.e-6 - - scheduler: - name: SequentialLR - interval: step - schedulers: - - name: LinearLR - interval: step - args: - start_factor: 1e-6 - end_factor: 1.0 - total_iters: 5000 - - name: CosineAnnealingLR - interval: step - args: - T_max: 5000 - eta_min: 0. - milestones: [5000] - -trainer: - num_nodes: 1 - max_epochs: 100000 - log_every_n_steps: 5 - num_sanity_val_steps: 1 - # val_check_interval: 200 - check_val_every_n_epoch: 10 - enable_progress_bar: true - precision: 16-mixed - -checkpoint: - save_last: true - save_top_k: -1 - every_n_train_steps: 5000 \ No newline at end of file diff --git a/configs/shape-autoencoder/l1024-e64-ne8-nd16.yaml b/configs/shape-autoencoder/michelangelo-l768-e64-ne8-nd16.yaml old mode 100755 new mode 100644 similarity index 71% rename from configs/shape-autoencoder/l1024-e64-ne8-nd16.yaml rename to configs/shape-autoencoder/michelangelo-l768-e64-ne8-nd16.yaml index 099e7c3c29171c336fd8092515569e429480b48a..cb7443872e844df25bd5001bf4963f4b08404c9f --- a/configs/shape-autoencoder/l1024-e64-ne8-nd16.yaml +++ b/configs/shape-autoencoder/michelangelo-l768-e64-ne8-nd16.yaml @@ -1,25 +1,25 @@ exp_root_dir: "outputs" -name: "michelangelo-autoencoder/l1024-e64-ne8-nd16" -tag: "${rmspace:n${data.n_samples}+${data.supervision_type}+rot${data.rotate}+noise${data.noise_sigma}+${system.shape_model.embed_type}+dsample${system.shape_model.use_downsample}+pfeat${system.shape_model.point_feats}+logits${system.loss.lambda_logits}+kl${system.loss.lambda_kl}+lr${system.optimizer.args.lr},_}" +name: "michelangelo-autoencoder/michelangelo-l768-e64-ne8-nd16" +tag: "${rmspace:n${data.n_samples}+${data.supervision_type}+rot${data.rotate_points}+noise${data.noise_sigma}+${system.shape_model.embed_type}+dsample${system.shape_model.use_downsample}+pfeat${system.shape_model.point_feats}+logits${system.loss.lambda_logits}+kl${system.loss.lambda_kl}+lr${system.optimizer.args.lr},_}" seed: 0 -data_type: "objaverse-datamodule" +data_type: "Objaverse-datamodule" data: - root_dir: "data/objaverse_clean/sdf_100k" - data_type: "sdf" - n_samples: 4096 - noise_sigma: 0. - rotate: False - + root_dir: ./data/objaverse + + load_geometry: True # whether to load geometry + geo_data_type: "tsdf" + n_samples: 10240 load_supervision: True supervision_type: "occupancy" - n_supervision: 4096 + n_supervision: 10240 + tsdf_threshold: 0.0078125 # threshold for truncating sdf values, used when input is sdf load_image: False # whether to load images load_caption: False # whether to load captions - batch_size: 128 - num_workers: 16 + batch_size: 8 + num_workers: 0 system_type: "shape-autoencoder-system" system: @@ -84,8 +84,7 @@ trainer: max_epochs: 100000 log_every_n_steps: 5 num_sanity_val_steps: 1 - # val_check_interval: 200 - check_val_every_n_epoch: 10 + check_val_every_n_epoch: 600 enable_progress_bar: true precision: 16-mixed diff --git a/craftsman/__init__.py b/craftsman/__init__.py index aaae1910181ee66ba098822f3009e44765e4a3a2..e5183952013695e1ef5fe4dc42abc2b9bc3f6e76 100755 --- a/craftsman/__init__.py +++ b/craftsman/__init__.py @@ -1,4 +1,7 @@ import importlib +from .pipeline import ( + CraftsManPipeline +) __modules__ = {} diff --git a/craftsman/__pycache__/__init__.cpython-310.pyc b/craftsman/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a51ef6f8fe535a1397913052837454aaa7d53f7 Binary files /dev/null and b/craftsman/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/__pycache__/__init__.cpython-311.pyc b/craftsman/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index fdd02739ebf30ab292a84fde381cb39014769ca0..0000000000000000000000000000000000000000 Binary files a/craftsman/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/craftsman/__pycache__/__init__.cpython-38.pyc b/craftsman/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 1ae602b962e76d70f08d6a8635c2e94951963c9b..0000000000000000000000000000000000000000 Binary files a/craftsman/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/__pycache__/pipeline.cpython-310.pyc b/craftsman/__pycache__/pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11f8f99c630037594255abc91320ca76286dcedb Binary files /dev/null and b/craftsman/__pycache__/pipeline.cpython-310.pyc differ diff --git a/craftsman/data/Objaverse.py b/craftsman/data/Objaverse.py new file mode 100755 index 0000000000000000000000000000000000000000..3d63daa7e2ec7861500daf19af589d791c006374 --- /dev/null +++ b/craftsman/data/Objaverse.py @@ -0,0 +1,65 @@ +import math +import os +import json +import re +import cv2 +from dataclasses import dataclass, field + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from craftsman import register +from craftsman.utils.typing import * +from craftsman.utils.config import parse_structured + +from .base import BaseDataModuleConfig, BaseDataset + +@dataclass +class ObjaverseDataModuleConfig(BaseDataModuleConfig): + pass + +class ObjaverseDataset(BaseDataset): + pass + + +@register("Objaverse-datamodule") +class ObjaverseDataModule(pl.LightningDataModule): + cfg: ObjaverseDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = ObjaverseDataset(self.cfg, "train") + if stage in [None, "fit", "validate"]: + self.val_dataset = ObjaverseDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = ObjaverseDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size, collate_fn=None, num_workers=0) -> DataLoader: + return DataLoader( + dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers + ) + + def train_dataloader(self) -> DataLoader: + return self.general_loader( + self.train_dataset, + batch_size=self.cfg.batch_size, + collate_fn=self.train_dataset.collate, + num_workers=self.cfg.num_workers + ) + + def val_dataloader(self) -> DataLoader: + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) \ No newline at end of file diff --git a/craftsman/data/__init__.py b/craftsman/data/__init__.py old mode 100644 new mode 100755 index 6548321fbbaf745526a3060ad9e3826c793b2f2b..e3b5936f9d527c905000525bd6763b277909d10e --- a/craftsman/data/__init__.py +++ b/craftsman/data/__init__.py @@ -1,3 +1,3 @@ from . import ( - objaverse + Objaverse ) \ No newline at end of file diff --git a/craftsman/data/__pycache__/Objaverse.cpython-310.pyc b/craftsman/data/__pycache__/Objaverse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e386501d4e556a0b58284856146c0cc902fe116c Binary files /dev/null and b/craftsman/data/__pycache__/Objaverse.cpython-310.pyc differ diff --git a/craftsman/data/__pycache__/__init__.cpython-310.pyc b/craftsman/data/__pycache__/__init__.cpython-310.pyc index 533244b5257e1439dbf99999ec2a39f1a9a5ab3b..a3984e5da4f87d7c9daa51190557b2d1b52905ae 100644 Binary files a/craftsman/data/__pycache__/__init__.cpython-310.pyc and b/craftsman/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/data/__pycache__/__init__.cpython-38.pyc b/craftsman/data/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index fe7fb6e708782671297bc33419978a06181211b3..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/data/__pycache__/base.cpython-310.pyc b/craftsman/data/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598c77443204d9b4fedab25daee82170de9c59e0 Binary files /dev/null and b/craftsman/data/__pycache__/base.cpython-310.pyc differ diff --git a/craftsman/data/__pycache__/objaverse.cpython-310.pyc b/craftsman/data/__pycache__/objaverse.cpython-310.pyc deleted file mode 100644 index 48f210fd6ed4f965ce3beedc195ba6ed98427e7f..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/objaverse.cpython-310.pyc and /dev/null differ diff --git a/craftsman/data/__pycache__/objaverse.cpython-38.pyc b/craftsman/data/__pycache__/objaverse.cpython-38.pyc deleted file mode 100644 index 8285f3527159173cb93626e5c3862c743f4c47ff..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/objaverse.cpython-38.pyc and /dev/null differ diff --git a/craftsman/data/__pycache__/objaverse_parquet.cpython-310.pyc b/craftsman/data/__pycache__/objaverse_parquet.cpython-310.pyc deleted file mode 100644 index 218ef66f65b8d02539c9e112aaa909ef61d70d53..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/objaverse_parquet.cpython-310.pyc and /dev/null differ diff --git a/craftsman/data/__pycache__/objaverse_patch_sdf.cpython-310.pyc b/craftsman/data/__pycache__/objaverse_patch_sdf.cpython-310.pyc deleted file mode 100644 index e19b0883c14dc91bb0fb558ad2c7c85c37b4b72a..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/objaverse_patch_sdf.cpython-310.pyc and /dev/null differ diff --git a/craftsman/data/__pycache__/objaverse_sdf.cpython-310.pyc b/craftsman/data/__pycache__/objaverse_sdf.cpython-310.pyc deleted file mode 100644 index ac84be9d68686c37a5dd03a0ac8eb73825d59922..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/objaverse_sdf.cpython-310.pyc and /dev/null differ diff --git a/craftsman/data/__pycache__/objaverse_sdf_parquet.cpython-310.pyc b/craftsman/data/__pycache__/objaverse_sdf_parquet.cpython-310.pyc deleted file mode 100644 index 7f5f766e30b6556ff0ff83c8a1b68c778631627b..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/objaverse_sdf_parquet.cpython-310.pyc and /dev/null differ diff --git a/craftsman/data/__pycache__/shape.cpython-310.pyc b/craftsman/data/__pycache__/shape.cpython-310.pyc deleted file mode 100644 index a41c08a10c6c4cecadd453232efa80d0cc48d01d..0000000000000000000000000000000000000000 Binary files a/craftsman/data/__pycache__/shape.cpython-310.pyc and /dev/null differ diff --git a/craftsman/data/base.py b/craftsman/data/base.py new file mode 100755 index 0000000000000000000000000000000000000000..a89fcbf466c1af978e498fb383eb52df10bc8d82 --- /dev/null +++ b/craftsman/data/base.py @@ -0,0 +1,230 @@ +import math +import os +import json +import re +import cv2 +from dataclasses import dataclass, field + +import random +import imageio +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +from PIL import Image + +from craftsman.utils.typing import * + +def fit_bounding_box(img, mask, marign_pix_dis, background_color): + # alpha_channel = img[:, :, 3] + alpha_channel = mask.numpy().squeeze() + height = np.any(alpha_channel, axis=1) + width = np.any(alpha_channel, axis=0) + h_min, h_max = np.where(height)[0][[0, -1]] + w_min, w_max = np.where(width)[0][[0, -1]] + box_height = h_max - h_min + box_width = w_max - w_min + cropped_image = img[h_min:h_max, w_min:w_max] + if box_height > box_width: + new_hight = 512 - 2 * marign_pix_dis + new_width = int((512 - 2 * marign_pix_dis) / (box_height) * box_width) + 1 + else: + new_hight = int((512 - 2 * marign_pix_dis) / (box_width) * box_height) + 1 + new_width = 512 - 2 * marign_pix_dis + new_h_min_pos = int((512 - new_hight) / 2 + 1) + new_h_max_pos = new_hight + new_h_min_pos + + new_w_min_pos = int((512 - new_width) / 2 + 1) + new_w_max_pos = new_width + new_w_min_pos + # extend of the bbox + new_image = np.full((512, 512, 3), background_color) + new_image[new_h_min_pos:new_h_max_pos, new_w_min_pos:new_w_max_pos, :] = cv2.resize(cropped_image.numpy(), (new_width, new_hight)) + + return torch.from_numpy(new_image) + +@dataclass +class BaseDataModuleConfig: + local_dir: str = None + + ################################# Geometry part ################################# + load_geometry: bool = True # whether to load geometry data + geo_data_type: str = "occupancy" # occupancy, sdf + geo_data_path: str = "" # path to the geometry data + # for occupancy and sdf data + n_samples: int = 4096 # number of points in input point cloud + upsample_ratio: int = 1 # upsample ratio for input point cloud + sampling_strategy: str = "random" # sampling strategy for input point cloud + scale: float = 1.0 # scale of the input point cloud and target supervision + load_supervision: bool = True # whether to load supervision + supervision_type: str = "occupancy" # occupancy, sdf, tsdf + tsdf_threshold: float = 0.05 # threshold for truncating sdf values, used when input is sdf + n_supervision: int = 10000 # number of points in supervision + + ################################# Image part ################################# + load_image: bool = False # whether to load images + image_data_path: str = "" # path to the image data + image_type: str = "rgb" # rgb, normal + background_color: Tuple[float, float, float] = field( + default_factory=lambda: (0.5, 0.5, 0.5) + ) + idx: Optional[List[int]] = None # index of the image to load + n_views: int = 1 # number of views + marign_pix_dis: int = 30 # margin of the bounding box + + +class BaseDataset(Dataset): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.cfg: BaseDataModuleConfig = cfg + self.split = split + + self.uids = json.load(open(f'{cfg.root_dir}/{split}.json')) + print(f"Loaded {len(self.uids)} {split} uids") + + def __len__(self): + return len(self.uids) + + + def _load_shape_from_occupancy_or_sdf(self, index: int) -> Dict[str, Any]: + if self.cfg.geo_data_type == "occupancy": + # for input point cloud, using Objaverse-MIX data + pointcloud = np.load(f'{self.cfg.geo_data_path}/{self.uids[index]}/pointcloud.npz') + surface = np.asarray(pointcloud['points']) * 2 # range from -1 to 1 + normal = np.asarray(pointcloud['normals']) + surface = np.concatenate([surface, normal], axis=1) + elif self.cfg.geo_data_type == "sdf": + # for sdf data with our own format + if re.match(r"\.\.", self.uids[index]): + data = np.load(f'{self.cfg.geo_data_path}/{self.uids[index]}.npz') + else: + data = np.load(f'{self.uids[index]}.npz') + # for input point cloud + surface = data["surface"] + else: + raise NotImplementedError(f"Data type {self.cfg.geo_data_type} not implemented") + + # random sampling + if self.cfg.sampling_strategy == "random": + rng = np.random.default_rng() + ind = rng.choice(surface.shape[0], self.cfg.upsample_ratio * self.cfg.n_samples, replace=False) + surface = surface[ind] + elif self.cfg.sampling_strategy == "fps": + import fpsample + kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(surface[:, :3], self.cfg.n_samples, h=5) + surface = surface[kdline_fps_samples_idx] + else: + raise NotImplementedError(f"sampling strategy {self.cfg.sampling_strategy} not implemented") + # rescale data + surface[:, :3] = surface[:, :3] * self.cfg.scale # target scale + ret = { + "uid": self.uids[index].split('/')[-1], + "surface": surface.astype(np.float32), + } + + return ret + + def _load_shape_supervision_occupancy_or_sdf(self, index: int) -> Dict[str, Any]: + # for supervision + ret = {} + if self.cfg.data_type == "occupancy": + points = np.load(f'{self.cfg.geo_data_path}/{self.uids[index]}/points.npz') + rand_points = np.asarray(points['points']) * 2 # range from -1.1 to 1.1 + occupancies = np.asarray(points['occupancies']) + occupancies = np.unpackbits(occupancies) + elif self.cfg.data_type == "sdf": + data = np.load(f'{self.cfg.geo_data_path}/{self.uids[index]}.npz') + rand_points = data['rand_points'] + sdfs = data['sdfs'] + else: + raise NotImplementedError(f"Data type {self.cfg.data_type} not implemented") + + # random sampling + rng = np.random.default_rng() + ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False) + rand_points = rand_points[ind] + rand_points = rand_points * self.cfg.scale + ret["rand_points"] = rand_points.astype(np.float32) + + if self.cfg.data_type == "occupancy": + assert self.cfg.supervision_type == "occupancy", "Only occupancy supervision is supported for occupancy data" + occupancies = occupancies[ind] + ret["occupancies"] = occupancies.astype(np.float32) + elif self.cfg.data_type == "sdf": + if self.cfg.supervision_type == "sdf": + ret["sdf"] = sdfs[ind].flatten().astype(np.float32) + elif self.cfg.supervision_type == "occupancy": + ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(np.float32) + elif self.cfg.supervision_type == "tsdf": + ret["sdf"] = sdfs[ind].flatten().astype(np.float32).clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold) / self.cfg.tsdf_threshold + else: + raise NotImplementedError(f"Supervision type {self.cfg.supervision_type} not implemented") + + return ret + + + def _load_image(self, index: int) -> Dict[str, Any]: + def _load_single_image(img_path, background_color, marign_pix_dis=None): + img = torch.from_numpy( + np.asarray( + Image.fromarray(imageio.v2.imread(img_path)) + .convert("RGBA") + ) + / 255.0 + ).float() + mask: Float[Tensor, "H W 1"] = img[:, :, -1:] + image: Float[Tensor, "H W 3"] = img[:, :, :3] * mask + background_color[ + None, None, : + ] * (1 - mask) + if marign_pix_dis is not None: + image = fit_bounding_box(image, mask, marign_pix_dis, background_color) + return image, mask + + if self.cfg.background_color == [-1, -1, -1]: + background_color = torch.randint(0, 256, (3,)) + else: + background_color = torch.as_tensor(self.cfg.background_color) + ret = {} + if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal": + assert self.cfg.n_views == 1, "Only single view is supported for single image" + sel_idx = random.choice(self.cfg.idx) + ret["sel_image_idx"] = sel_idx + if self.cfg.image_type == "rgb": + img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{'{:04d}'.format(sel_idx)}_rgb.png" + elif self.cfg.image_type == "normal": + img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{'{:04d}'.format(sel_idx)}_normal.png" + ret["image"], ret["mask"] = _load_single_image(img_path, background_color, self.cfg.marign_pix_dis) + + else: + raise NotImplementedError(f"Image type {self.cfg.image_type} not implemented") + + return ret + + def _get_data(self, index): + ret = {"uid": self.uids[index]} + # load geometry + if self.cfg.load_geometry: + if self.cfg.geo_data_type == "occupancy" or self.cfg.geo_data_type == "sdf": + # load shape + ret = self._load_shape_from_occupancy_or_sdf(index) + # load supervision for shape + if self.cfg.load_supervision: + ret.update(self._load_shape_supervision_occupancy_or_sdf(index)) + else: + raise NotImplementedError(f"Geo data type {self.cfg.geo_data_type} not implemented") + + # load image + if self.cfg.load_image: + ret.update(self._load_image(index)) + + return ret + + def __getitem__(self, index): + try: + return self._get_data(index) + except Exception as e: + print(f"Error in {self.uids[index]}: {e}") + return self.__getitem__(np.random.randint(len(self))) + + def collate(self, batch): + from torch.utils.data._utils.collate import default_collate_fn_map + return torch.utils.data.default_collate(batch) diff --git a/craftsman/data/objaverse.py b/craftsman/data/objaverse.py deleted file mode 100644 index 1e2f11c5e883149c4888a716e7e99e03db0c381e..0000000000000000000000000000000000000000 --- a/craftsman/data/objaverse.py +++ /dev/null @@ -1,311 +0,0 @@ -import math -import os -import json -from dataclasses import dataclass, field - -import random -import imageio -import numpy as np -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms -from PIL import Image -from transformers import CLIPImageProcessor, CLIPTokenizer - -from craftsman import register -from craftsman.utils.base import Updateable -from craftsman.utils.config import parse_structured -from craftsman.utils.typing import * - -def rot2eul(R): - beta = -np.arcsin(R[2,0]) - alpha = np.arctan2(R[2,1]/np.cos(beta),R[2,2]/np.cos(beta)) - gamma = np.arctan2(R[1,0]/np.cos(beta),R[0,0]/np.cos(beta)) - return np.array((alpha, beta, gamma)) - -def eul2rot(theta) : - R = np.array([[np.cos(theta[1])*np.cos(theta[2]), np.sin(theta[0])*np.sin(theta[1])*np.cos(theta[2]) - np.sin(theta[2])*np.cos(theta[0]), np.sin(theta[1])*np.cos(theta[0])*np.cos(theta[2]) + np.sin(theta[0])*np.sin(theta[2])], - [np.sin(theta[2])*np.cos(theta[1]), np.sin(theta[0])*np.sin(theta[1])*np.sin(theta[2]) + np.cos(theta[0])*np.cos(theta[2]), np.sin(theta[1])*np.sin(theta[2])*np.cos(theta[0]) - np.sin(theta[0])*np.cos(theta[2])], - [-np.sin(theta[1]), np.sin(theta[0])*np.cos(theta[1]), np.cos(theta[0])*np.cos(theta[1])]]) - return R - -@dataclass -class ObjaverseDataModuleConfig: - root_dir: str = None - data_type: str = "occupancy" # occupancy or sdf - n_samples: int = 4096 # number of points in input point cloud - scale: float = 1.0 # scale of the input point cloud and target supervision - noise_sigma: float = 0.0 # noise level of the input point cloud - - load_supervision: bool = True # whether to load supervision - supervision_type: str = "occupancy" # occupancy, sdf, tsdf, tsdf_w_surface - n_supervision: int = 10000 # number of points in supervision - - load_image: bool = False # whether to load images - image_data_path: str = "" # path to the image data - image_type: str = "rgb" # rgb, normal - background_color: Tuple[float, float, float] = field( - default_factory=lambda: (1.0, 1.0, 1.0) - ) - idx: Optional[List[int]] = None # index of the image to load - n_views: int = 1 # number of views - rotate_points: bool = False # whether to rotate the input point cloud and the supervision - - load_caption: bool = False # whether to load captions - caption_type: str = "text" # text, clip_embeds - tokenizer_pretrained_model_name_or_path: str = "" - - batch_size: int = 32 - num_workers: int = 0 - -class ObjaverseDataset(Dataset): - def __init__(self, cfg: Any, split: str) -> None: - super().__init__() - self.cfg: ObjaverseDataModuleConfig = cfg - self.split = split - - self.uids = json.load(open(f'{cfg.root_dir}/{split}.json')) - print(f"Loaded {len(self.uids)} {split} uids") - - if self.cfg.load_caption: - self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.tokenizer_pretrained_model_name_or_path) - - self.background_color = torch.as_tensor(self.cfg.background_color) - self.distance = 1.0 - self.camera_embedding = torch.as_tensor([ - [[1, 0, 0, 0], - [0, 0, -1, -self.distance], - [0, 1, 0, 0], - [0, 0, 0, 1]], # front to back - - [[0, 0, 1, self.distance], - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 0, 1]], # right to left - - [[-1, 0, 0, 0], - [0, 0, 1, self.distance], - [0, 1, 0, 0], - [0, 0, 0, 1]], # back to front - - [[0, 0, -1, -self.distance], - [-1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 0, 1]], # left to right - ], dtype=torch.float32) - if self.cfg.n_views != 1: - assert self.cfg.n_views == self.camera_embedding.shape[0] - - def __len__(self): - return len(self.uids) - - def _load_shape(self, index: int): - if self.cfg.data_type == "occupancy": - # for input point cloud - pointcloud = np.load(f'{self.cfg.root_dir}/{self.uids[index]}/pointcloud.npz') - surface = np.asarray(pointcloud['points']) * 2 # range from -1 to 1 - normal = np.asarray(pointcloud['normals']) - surface = np.concatenate([surface, normal], axis=1) - elif self.cfg.data_type == "sdf": - data = np.load(f'{self.cfg.root_dir}/{self.uids[index]}.npz') - # for input point cloud - surface = data["surface"] - else: - raise NotImplementedError(f"Data type {self.cfg.data_type} not implemented") - - # random sampling - rng = np.random.default_rng() - ind = rng.choice(surface.shape[0], self.cfg.n_samples, replace=False) - surface = surface[ind] - # rescale data - surface[:, :3] = surface[:, :3] * self.cfg.scale # target scale - # add noise to input point cloud - surface[:, :3] += (np.random.rand(surface.shape[0], 3) * 2 - 1) * self.cfg.noise_sigma - ret = { - "uid": self.uids[index].split('/')[-1], - "surface": surface.astype(np.float32), - } - - return ret - - def _load_shape_supervision(self, index: int): - # for supervision - ret = {} - if self.cfg.data_type == "occupancy": - points = np.load(f'{self.cfg.root_dir}/{self.uids[index]}/points.npz') - rand_points = np.asarray(points['points']) * 2 # range from -1.1 to 1.1 - occupancies = np.asarray(points['occupancies']) - occupancies = np.unpackbits(occupancies) - elif self.cfg.data_type == "sdf": - data = np.load(f'{self.cfg.root_dir}/{self.uids[index]}.npz') - rand_points = data['rand_points'] - sdfs = data['sdfs'] - else: - raise NotImplementedError(f"Data type {self.cfg.data_type} not implemented") - - # random sampling - rng = np.random.default_rng() - ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False) - rand_points = rand_points[ind] - rand_points = rand_points * self.cfg.scale - ret["rand_points"] = rand_points.astype(np.float32) - - if self.cfg.data_type == "occupancy": - assert self.cfg.supervision_type == "occupancy", "Only occupancy supervision is supported for occupancy data" - occupancies = occupancies[ind] - ret["occupancies"] = occupancies.astype(np.float32) - elif self.cfg.data_type == "sdf": - if self.cfg.supervision_type == "sdf": - ret["sdf"] = sdfs[ind].flatten().astype(np.float32) - elif self.cfg.supervision_type == "occupancy": - ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(np.float32) - else: - raise NotImplementedError(f"Supervision type {self.cfg.supervision_type} not implemented") - - return ret - - def _load_image(self, index: int): - def _load_single_image(img_path): - img = torch.from_numpy( - np.asarray( - Image.fromarray(imageio.v2.imread(img_path)) - .convert("RGBA") - ) - / 255.0 - ).float() - mask: Float[Tensor, "H W 1"] = img[:, :, -1:] - image: Float[Tensor, "H W 3"] = img[:, :, :3] * mask + self.background_color[ - None, None, : - ] * (1 - mask) - return image - - ret = {} - if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal": - assert self.cfg.n_views == 1, "Only single view is supported for single image" - sel_idx = random.choice(self.cfg.idx) - ret["sel_image_idx"] = sel_idx - if self.cfg.image_type == "rgb": - img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx}.png" - elif self.cfg.image_type == "normal": - img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx}_normal.png" - ret["image"] = _load_single_image(img_path) - ret["c2w"] = self.camera_embedding[sel_idx % 4] - elif self.cfg.image_type == "mvrgb" or self.cfg.image_type == "mvnormal": - sel_idx = random.choice(self.cfg.idx) - ret["sel_image_idx"] = sel_idx - mvimages = [] - for i in range(self.cfg.n_views): - if self.cfg.image_type == "mvrgb": - img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx+i}.png" - elif self.cfg.image_type == "mvnormal": - img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx+i}_normal.png" - mvimages.append(_load_single_image(img_path)) - ret["mvimages"] = torch.stack(mvimages) - ret["c2ws"] = self.camera_embedding - else: - raise NotImplementedError(f"Image type {self.cfg.image_type} not implemented") - - return ret - - def _load_caption(self, index: int, drop_text_embed: bool = False): - ret = {} - if self.cfg.caption_type == "text": - caption = eval(json.load(open(f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f'/annotation.json'))) - texts = [v for k, v in caption.items()] - sel_idx = random.randint(0, len(texts) - 1) - ret["sel_caption_idx"] = sel_idx - ret['text_input_ids'] = self.tokenizer( - texts[sel_idx] if not drop_text_embed else "", - max_length=self.tokenizer.model_max_length, - padding="max_length", - truncation=True, - return_tensors="pt" - ).input_ids.detach() - else: - raise NotImplementedError(f"Caption type {self.cfg.caption_type} not implemented") - - return ret - - def get_data(self, index): - # load shape - ret = self._load_shape(index) - - # load supervision for shape - if self.cfg.load_supervision: - ret.update(self._load_shape_supervision(index)) - - # load image - if self.cfg.load_image: - ret.update(self._load_image(index)) - - # load the rotation of the object and rotate the camera - rots = np.load(f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f'/rots.npy')[ret['sel_image_idx']].astype(np.float32) - rots = torch.tensor(rots[:3, :3], dtype=torch.float32) - if "c2ws" in ret.keys(): - ret["c2ws"][:, :3, :3] = torch.matmul(rots, ret["c2ws"][:, :3, :3]) - ret["c2ws"][:, :3, 3] = torch.matmul(rots, ret["c2ws"][:, :3, 3].unsqueeze(-1)).squeeze(-1) - elif "c2w" in ret.keys(): - ret["c2w"][:3, :3] = torch.matmul(rots, ret["c2w"][:3, :3]) - ret["c2w"][:3, 3] = torch.matmul(rots, ret["c2w"][:3, 3].unsqueeze(-1)).squeeze(-1) - - # load caption - if self.cfg.load_caption: - ret.update(self._load_caption(index)) - - return ret - - def __getitem__(self, index): - try: - return self.get_data(index) - except Exception as e: - print(f"Error in {self.uids[index]}: {e}") - return self.__getitem__(np.random.randint(len(self))) - - - def collate(self, batch): - batch = torch.utils.data.default_collate(batch) - return batch - - -@register("objaverse-datamodule") -class ObjaverseDataModule(pl.LightningDataModule): - cfg: ObjaverseDataModuleConfig - - def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: - super().__init__() - self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg) - - def setup(self, stage=None) -> None: - if stage in [None, "fit"]: - self.train_dataset = ObjaverseDataset(self.cfg, "train") - if stage in [None, "fit", "validate"]: - self.val_dataset = ObjaverseDataset(self.cfg, "val") - if stage in [None, "test", "predict"]: - self.test_dataset = ObjaverseDataset(self.cfg, "test") - - def prepare_data(self): - pass - - def general_loader(self, dataset, batch_size, collate_fn=None, num_workers=0) -> DataLoader: - return DataLoader( - dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers - ) - - def train_dataloader(self) -> DataLoader: - return self.general_loader( - self.train_dataset, - batch_size=self.cfg.batch_size, - collate_fn=self.train_dataset.collate, - num_workers=self.cfg.num_workers - ) - - def val_dataloader(self) -> DataLoader: - return self.general_loader(self.val_dataset, batch_size=1) - - def test_dataloader(self) -> DataLoader: - return self.general_loader(self.test_dataset, batch_size=1) - - def predict_dataloader(self) -> DataLoader: - return self.general_loader(self.test_dataset, batch_size=1) \ No newline at end of file diff --git a/craftsman/models/__init__.py b/craftsman/models/__init__.py index 73b20fbfa46a26d96e273dd3b0a77a765bbe58c7..c051f5d4c57886da8615ef01dadcf80fd0817481 100755 --- a/craftsman/models/__init__.py +++ b/craftsman/models/__init__.py @@ -1,5 +1,5 @@ from . import ( autoencoders, conditional_encoders, - denoisers, + denoisers ) \ No newline at end of file diff --git a/craftsman/models/__pycache__/__init__.cpython-310.pyc b/craftsman/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f2a4aa9900ea2774cdf420f186133bc7841c02a Binary files /dev/null and b/craftsman/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/models/__pycache__/__init__.cpython-38.pyc b/craftsman/models/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 34f64f0875b96cc02f7478e3ccb7afdd82ee9022..0000000000000000000000000000000000000000 Binary files a/craftsman/models/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/autoencoders/__pycache__/__init__.cpython-310.pyc b/craftsman/models/autoencoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..816e8317dfc4c21814e583abccec00f656238ef0 Binary files /dev/null and b/craftsman/models/autoencoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/models/autoencoders/__pycache__/__init__.cpython-38.pyc b/craftsman/models/autoencoders/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index ecbd3992bc3cab0917eb9a8b3520a73d7f96955c..0000000000000000000000000000000000000000 Binary files a/craftsman/models/autoencoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/autoencoders/__pycache__/michelangelo_autoencoder.cpython-310.pyc b/craftsman/models/autoencoders/__pycache__/michelangelo_autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dc2375a7d7d199b2d62a491b5a98b85c9587634 Binary files /dev/null and b/craftsman/models/autoencoders/__pycache__/michelangelo_autoencoder.cpython-310.pyc differ diff --git a/craftsman/models/autoencoders/__pycache__/michelangelo_autoencoder.cpython-38.pyc b/craftsman/models/autoencoders/__pycache__/michelangelo_autoencoder.cpython-38.pyc deleted file mode 100644 index 62261736e29d122bd646f75876afd7f6ce4bc2ea..0000000000000000000000000000000000000000 Binary files a/craftsman/models/autoencoders/__pycache__/michelangelo_autoencoder.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/autoencoders/__pycache__/utils.cpython-38.pyc b/craftsman/models/autoencoders/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index 584e668da960b1c80b674fc8588b67b1494c1f25..0000000000000000000000000000000000000000 Binary files a/craftsman/models/autoencoders/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/autoencoders/michelangelo_autoencoder.py b/craftsman/models/autoencoders/michelangelo_autoencoder.py index 6a8e211d85f80589c97f7f1b544edde37473956d..6cacc0dc3865b12aa8871f0e7d8b12b3375823f9 100755 --- a/craftsman/models/autoencoders/michelangelo_autoencoder.py +++ b/craftsman/models/autoencoders/michelangelo_autoencoder.py @@ -2,9 +2,10 @@ from dataclasses import dataclass import math import torch +import numpy as np +import random import torch.nn as nn from einops import repeat, rearrange -from transformers import CLIPModel import craftsman from craftsman.models.transformers.perceiver_1d import Perceiver @@ -12,8 +13,309 @@ from craftsman.models.transformers.attention import ResidualCrossAttentionBlock from craftsman.utils.checkpoint import checkpoint from craftsman.utils.base import BaseModule from craftsman.utils.typing import * +from craftsman.utils.misc import get_world_size +from craftsman.utils.ops import generate_dense_grid_points + +###################### Utils +VALID_EMBED_TYPES = ["identity", "fourier", "learned_fourier", "siren"] + +class FourierEmbedder(nn.Module): + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + +class LearnedFourierEmbedder(nn.Module): + def __init__(self, input_dim, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + per_channel_dim = half_dim // input_dim + self.weights = nn.Parameter(torch.randn(per_channel_dim)) + + self.out_dim = self.get_dims(input_dim) + + def forward(self, x): + # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] + freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) + fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + def get_dims(self, input_dim): + return input_dim * (self.weights.shape[0] * 2 + 1) + +class Sine(nn.Module): + def __init__(self, w0 = 1.): + super().__init__() + self.w0 = w0 + def forward(self, x): + return torch.sin(self.w0 * x) + +class Siren(nn.Module): + def __init__( + self, + in_dim, + out_dim, + w0 = 1., + c = 6., + is_first = False, + use_bias = True, + activation = None, + dropout = 0. + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.is_first = is_first + + weight = torch.zeros(out_dim, in_dim) + bias = torch.zeros(out_dim) if use_bias else None + self.init_(weight, bias, c = c, w0 = w0) + + self.weight = nn.Parameter(weight) + self.bias = nn.Parameter(bias) if use_bias else None + self.activation = Sine(w0) if activation is None else activation + self.dropout = nn.Dropout(dropout) + + def init_(self, weight, bias, c, w0): + dim = self.in_dim + + w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0) + weight.uniform_(-w_std, w_std) + + if bias is not None: + bias.uniform_(-w_std, w_std) + + def forward(self, x): + out = F.linear(x, self.weight, self.bias) + out = self.activation(out) + out = self.dropout(out) + return out + +def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True): + if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): + return nn.Identity(), input_dim + + elif embed_type == "fourier": + embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + elif embed_type == "learned_fourier": + embedder_obj = LearnedFourierEmbedder(in_channels=input_dim, dim=num_freqs) + + elif embed_type == "siren": + embedder_obj = Siren(in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim) + + else: + raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") + return embedder_obj + + +###################### AutoEncoder +class AutoEncoder(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: str = "" + num_latents: int = 256 + embed_dim: int = 64 + width: int = 768 + + cfg: Config + + def configure(self) -> None: + super().configure() + + def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + raise NotImplementedError + + def decode(self, z: torch.FloatTensor) -> torch.FloatTensor: + raise NotImplementedError + + def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): + posterior = None + if self.cfg.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + if sample_posterior: + kl_embed = posterior.sample() + else: + kl_embed = posterior.mode() + else: + kl_embed = latents + return kl_embed, posterior + + def forward(self, + surface: torch.FloatTensor, + queries: torch.FloatTensor, + sample_posterior: bool = True): + shape_latents, kl_embed, posterior = self.encode(surface, sample_posterior=sample_posterior) + + latents = self.decode(kl_embed) # [B, num_latents, width] + + logits = self.query(queries, latents) # [B,] + + return shape_latents, latents, posterior, logits + + def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor) -> torch.FloatTensor: + raise NotImplementedError + + @torch.no_grad() + def extract_geometry(self, + latents: torch.FloatTensor, + extract_mesh_func: str = "mc", + bounds: Union[Tuple[float], List[float], float] = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05), + octree_depth: int = 8, + num_chunks: int = 10000, + ): + + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_depth=octree_depth, + indexing="ij" + ) + xyz_samples = torch.FloatTensor(xyz_samples) + batch_size = latents.shape[0] + + batch_logits = [] + for start in range(0, xyz_samples.shape[0], num_chunks): + queries = xyz_samples[start: start + num_chunks, :].to(latents) + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + + logits = self.query(batch_queries, latents) + batch_logits.append(logits.cpu()) + + grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float().numpy() + + mesh_v_f = [] + has_surface = np.zeros((batch_size,), dtype=np.bool_) + for i in range(batch_size): + try: + if extract_mesh_func == "mc": + from skimage import measure + vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") + # vertices, faces = mcubes.marching_cubes(grid_logits[i], 0) + vertices = vertices / grid_size * bbox_size + bbox_min + faces = faces[:, [2, 1, 0]] + elif extract_mesh_func == "diffmc": + from diso import DiffMC + diffmc = DiffMC(dtype=torch.float32).to(latents.device) + vertices, faces = diffmc(-torch.tensor(grid_logits[i]).float().to(latents.device), isovalue=0) + vertices = vertices * 2 - 1 + vertices = vertices.cpu().numpy() + faces = faces.cpu().numpy() + elif extract_mesh_func == "diffdmc": + from diso import DiffDMC + diffmc = DiffDMC(dtype=torch.float32).to(latents.device) + vertices, faces = diffmc(-torch.tensor(grid_logits[i]).float().to(latents.device), isovalue=0) + vertices = vertices * 2 - 1 + vertices = vertices.cpu().numpy() + faces = faces.cpu().numpy() + else: + raise NotImplementedError(f"{extract_mesh_func} not implement") + mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces.astype(np.int64)))) + has_surface[i] = True + except: + mesh_v_f.append((None, None)) + has_surface[i] = False + + return mesh_v_f, has_surface + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2)): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=dims) + + def nll(self, sample, dims=(1, 2)): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean -from .utils import AutoEncoder, FourierEmbedder, get_embedder class PerceiverCrossAttentionEncoder(nn.Module): def __init__(self, @@ -29,7 +331,10 @@ class PerceiverCrossAttentionEncoder(nn.Module): qkv_bias: bool = True, use_ln_post: bool = False, use_flash: bool = False, - use_checkpoint: bool = False): + use_checkpoint: bool = False, + use_multi_reso: bool = False, + resolutions: list = [], + sampling_prob: list = []): super().__init__() @@ -37,6 +342,9 @@ class PerceiverCrossAttentionEncoder(nn.Module): self.num_latents = num_latents self.use_downsample = use_downsample self.embed_point_feats = embed_point_feats + self.use_multi_reso = use_multi_reso + self.resolutions = resolutions + self.sampling_prob = sampling_prob if not self.use_downsample: self.query = nn.Parameter(torch.randn((num_latents, width)) * 0.02) @@ -83,7 +391,7 @@ class PerceiverCrossAttentionEncoder(nn.Module): """ bs, N, D = pc.shape - + data = self.embedder(pc) if feats is not None: if self.embed_point_feats: @@ -91,13 +399,31 @@ class PerceiverCrossAttentionEncoder(nn.Module): data = torch.cat([data, feats], dim=-1) data = self.input_proj(data) + if self.use_multi_reso: + # number = 8192 + resolution = random.choice(self.resolutions, size=1, p=self.sampling_prob)[0] + + if resolution != N: + + flattened = pc.view(bs*N, D) # bs*N, 64. 103,4096,3 -> 421888,3 + batch = torch.arange(bs).to(pc.device) # 103 + batch = torch.repeat_interleave(batch, N) # bs*N. 421888 + pos = flattened + ratio = 1.0 * resolution / N # 0.0625 + idx = fps(pos, batch, ratio=ratio) #26368 + pc = pc.view(bs*N, -1)[idx].view(bs, -1, D) + bs,N,D=feats.shape + flattened1 = feats.view(bs*N, D) + feats= flattened1.view(bs*N, -1)[idx].view(bs, -1, D) + bs, N, D = pc.shape + if self.use_downsample: ###### fps from torch_cluster import fps - flattened = pc.view(bs*N, D) + flattened = pc.view(bs*N, D) # bs*N, 64 batch = torch.arange(bs).to(pc.device) - batch = torch.repeat_interleave(batch, N) + batch = torch.repeat_interleave(batch, N) # bs*N pos = flattened @@ -184,7 +510,9 @@ class MichelangeloAutoencoder(AutoEncoder): @dataclass class Config(BaseModule.Config): pretrained_model_name_or_path: str = "" + n_samples: int = 4096 use_downsample: bool = False + downsample_ratio: float = 0.0625 num_latents: int = 256 point_feats: int = 0 embed_point_feats: bool = False @@ -202,6 +530,9 @@ class MichelangeloAutoencoder(AutoEncoder): use_ln_post: bool = False use_flash: bool = False use_checkpoint: bool = True + use_multi_reso: Optional[bool] = False + resolutions: Optional[List[int]] = None + sampling_prob: Optional[List[float]] = None cfg: Config @@ -225,7 +556,10 @@ class MichelangeloAutoencoder(AutoEncoder): qkv_bias=self.cfg.qkv_bias, use_ln_post=self.cfg.use_ln_post, use_flash=self.cfg.use_flash, - use_checkpoint=self.cfg.use_checkpoint + use_checkpoint=self.cfg.use_checkpoint, + use_multi_reso=self.cfg.use_multi_reso, + resolutions=self.cfg.resolutions, + sampling_prob=self.cfg.sampling_prob ) if self.cfg.embed_dim > 0: @@ -269,7 +603,14 @@ class MichelangeloAutoencoder(AutoEncoder): if k.startswith('shape_model.'): _pretrained_ckpt[k.replace('shape_model.', '')] = v pretrained_ckpt = _pretrained_ckpt - self.load_state_dict(pretrained_ckpt, strict=True) + else: + _pretrained_ckpt = {} + for k, v in pretrained_ckpt.items(): + if k.startswith('shape_model.'): + _pretrained_ckpt[k.replace('shape_model.', '')] = v + pretrained_ckpt = _pretrained_ckpt + + self.load_state_dict(pretrained_ckpt, strict=False) def encode(self, @@ -288,7 +629,22 @@ class MichelangeloAutoencoder(AutoEncoder): assert surface.shape[-1] == 3 + self.cfg.point_feats, f"\ Expected {3 + self.cfg.point_feats} channels, got {surface.shape[-1]}" - pc, feats = surface[..., :3], surface[..., 3:] # B, n_samples, 3 + pc, feats = surface[..., :3], surface[..., 3:] # B, n_samples, 3 + bs, N, D = pc.shape + if N > self.cfg.n_samples: + # idx = furthest_point_sample(pc, self.cfg.n_samples) # (B, 3, npoint) + # pc = gather_operation(pc, idx).transpose(2, 1).contiguous() + # feats = gather_operation(feats, idx).transpose(2, 1).contiguous() + from torch_cluster import fps + flattened = pc.view(bs*N, D) # bs*N, 64 + batch = torch.arange(bs).to(pc.device) + batch = torch.repeat_interleave(batch, N) # bs*N + pos = flattened + ratio = self.cfg.n_samples / N + idx = fps(pos, batch, ratio=ratio) + pc = pc.view(bs*N, -1)[idx].view(bs, -1, pc.shape[-1]) + feats = feats.view(bs*N, -1)[idx].view(bs, -1, feats.shape[-1]) + shape_latents = self.encoder(pc, feats) # B, num_latents, width kl_embed, posterior = self.encode_kl_embed(shape_latents, sample_posterior) # B, num_latents, embed_dim @@ -324,81 +680,3 @@ class MichelangeloAutoencoder(AutoEncoder): logits = self.decoder(queries, latents).squeeze(-1) return logits - - - - -@craftsman.register("michelangelo-aligned-autoencoder") -class MichelangeloAlignedAutoencoder(MichelangeloAutoencoder): - r""" - A VAE model for encoding shapes into latents and decoding latent representations into shapes. - """ - @dataclass - class Config(MichelangeloAutoencoder.Config): - clip_model_version: Optional[str] = None - - cfg: Config - - def configure(self) -> None: - if self.cfg.clip_model_version is not None: - self.clip_model: CLIPModel = CLIPModel.from_pretrained(self.cfg.clip_model_version) - self.projection = nn.Parameter(torch.empty(self.cfg.width, self.clip_model.projection_dim)) - self.logit_scale = torch.exp(self.clip_model.logit_scale.data) - nn.init.normal_(self.projection, std=self.clip_model.projection_dim ** -0.5) - else: - self.projection = nn.Parameter(torch.empty(self.cfg.width, 768)) - nn.init.normal_(self.projection, std=768 ** -0.5) - - self.cfg.num_latents = self.cfg.num_latents + 1 - - super().configure() - - def encode(self, - surface: torch.FloatTensor, - sample_posterior: bool = True): - """ - Args: - surface (torch.FloatTensor): [B, N, 3+C] - sample_posterior (bool): - - Returns: - latents (torch.FloatTensor) - posterior (DiagonalGaussianDistribution or None): - """ - assert surface.shape[-1] == 3 + self.cfg.point_feats, f"\ - Expected {3 + self.cfg.point_feats} channels, got {surface.shape[-1]}" - - pc, feats = surface[..., :3], surface[..., 3:] # B, n_samples, 3 - shape_latents = self.encoder(pc, feats) # B, num_latents, width - shape_embeds = shape_latents[:, 0] # B, width - shape_latents = shape_latents[:, 1:] # B, num_latents-1, width - kl_embed, posterior = self.encode_kl_embed(shape_latents, sample_posterior) # B, num_latents, embed_dim - - shape_embeds = shape_embeds @ self.projection - return shape_embeds, kl_embed, posterior - - def forward(self, - surface: torch.FloatTensor, - queries: torch.FloatTensor, - sample_posterior: bool = True): - """ - Args: - surface (torch.FloatTensor): [B, N, 3+C] - queries (torch.FloatTensor): [B, P, 3] - sample_posterior (bool): - - Returns: - shape_embeds (torch.FloatTensor): [B, width] - latents (torch.FloatTensor): [B, num_latents, embed_dim] - posterior (DiagonalGaussianDistribution or None). - logits (torch.FloatTensor): [B, P] - """ - - shape_embeds, kl_embed, posterior = self.encode(surface, sample_posterior=sample_posterior) - - latents = self.decode(kl_embed) # [B, num_latents - 1, width] - - logits = self.query(queries, latents) # [B,] - - return shape_embeds, latents, posterior, logits - diff --git a/craftsman/models/autoencoders/utils.py b/craftsman/models/autoencoders/utils.py deleted file mode 100755 index 53a70e5a8e5c983619cd0d82c1bc859d511849d1..0000000000000000000000000000000000000000 --- a/craftsman/models/autoencoders/utils.py +++ /dev/null @@ -1,302 +0,0 @@ -from dataclasses import dataclass - -import torch -import torch.nn as nn -from torch import distributed as tdist -from torch.nn import functional as F -import math -import mcubes -import numpy as np -from einops import repeat, rearrange -from skimage import measure - -from craftsman.utils.base import BaseModule -from craftsman.utils.typing import * -from craftsman.utils.misc import get_world_size -from craftsman.utils.ops import generate_dense_grid_points - -VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] - -class FourierEmbedder(nn.Module): - def __init__(self, - num_freqs: int = 6, - logspace: bool = True, - input_dim: int = 3, - include_input: bool = True, - include_pi: bool = True) -> None: - super().__init__() - - if logspace: - frequencies = 2.0 ** torch.arange( - num_freqs, - dtype=torch.float32 - ) - else: - frequencies = torch.linspace( - 1.0, - 2.0 ** (num_freqs - 1), - num_freqs, - dtype=torch.float32 - ) - - if include_pi: - frequencies *= torch.pi - - self.register_buffer("frequencies", frequencies, persistent=False) - self.include_input = include_input - self.num_freqs = num_freqs - - self.out_dim = self.get_dims(input_dim) - - def get_dims(self, input_dim): - temp = 1 if self.include_input or self.num_freqs == 0 else 0 - out_dim = input_dim * (self.num_freqs * 2 + temp) - - return out_dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.num_freqs > 0: - embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) - if self.include_input: - return torch.cat((x, embed.sin(), embed.cos()), dim=-1) - else: - return torch.cat((embed.sin(), embed.cos()), dim=-1) - else: - return x - - -class LearnedFourierEmbedder(nn.Module): - def __init__(self, input_dim, dim): - super().__init__() - assert (dim % 2) == 0 - half_dim = dim // 2 - per_channel_dim = half_dim // input_dim - self.weights = nn.Parameter(torch.randn(per_channel_dim)) - - self.out_dim = self.get_dims(input_dim) - - def forward(self, x): - # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] - freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) - fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) - return fouriered - - def get_dims(self, input_dim): - return input_dim * (self.weights.shape[0] * 2 + 1) - -class Sine(nn.Module): - def __init__(self, w0 = 1.): - super().__init__() - self.w0 = w0 - def forward(self, x): - return torch.sin(self.w0 * x) - -class Siren(nn.Module): - def __init__( - self, - in_dim, - out_dim, - w0 = 1., - c = 6., - is_first = False, - use_bias = True, - activation = None, - dropout = 0. - ): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - self.is_first = is_first - - weight = torch.zeros(out_dim, in_dim) - bias = torch.zeros(out_dim) if use_bias else None - self.init_(weight, bias, c = c, w0 = w0) - - self.weight = nn.Parameter(weight) - self.bias = nn.Parameter(bias) if use_bias else None - self.activation = Sine(w0) if activation is None else activation - self.dropout = nn.Dropout(dropout) - - def init_(self, weight, bias, c, w0): - dim = self.in_dim - - w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0) - weight.uniform_(-w_std, w_std) - - if bias is not None: - bias.uniform_(-w_std, w_std) - - def forward(self, x): - out = F.linear(x, self.weight, self.bias) - out = self.activation(out) - out = self.dropout(out) - return out - -def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True): - if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): - return nn.Identity(), input_dim - - elif embed_type == "fourier": - embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) - - elif embed_type == "learned_fourier": - embedder_obj = LearnedFourierEmbedder(in_channels=input_dim, dim=num_freqs) - - elif embed_type == "siren": - embedder_obj = Siren(in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim) - - elif embed_type == "hashgrid": - raise NotImplementedError - - elif embed_type == "sphere_harmonic": - raise NotImplementedError - - else: - raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") - return embedder_obj - - -###################### AutoEncoder -class AutoEncoder(BaseModule): - @dataclass - class Config(BaseModule.Config): - pretrained_model_name_or_path: str = "" - num_latents: int = 256 - embed_dim: int = 64 - width: int = 768 - - cfg: Config - - def configure(self) -> None: - super().configure() - - def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - raise NotImplementedError - - def decode(self, z: torch.FloatTensor) -> torch.FloatTensor: - raise NotImplementedError - - def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): - posterior = None - if self.cfg.embed_dim > 0: - moments = self.pre_kl(latents) - posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) - if sample_posterior: - kl_embed = posterior.sample() - else: - kl_embed = posterior.mode() - else: - kl_embed = latents - return kl_embed, posterior - - def forward(self, - surface: torch.FloatTensor, - queries: torch.FloatTensor, - sample_posterior: bool = True): - shape_latents, kl_embed, posterior = self.encode(surface, sample_posterior=sample_posterior) - - latents = self.decode(kl_embed) # [B, num_latents, width] - - logits = self.query(queries, latents) # [B,] - - return shape_latents, latents, posterior, logits - - def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor) -> torch.FloatTensor: - raise NotImplementedError - - @torch.no_grad() - def extract_geometry(self, - latents: torch.FloatTensor, - bounds: Union[Tuple[float], List[float], float] = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05), - octree_depth: int = 8, - num_chunks: int = 10000, - ): - - if isinstance(bounds, float): - bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] - - bbox_min = np.array(bounds[0:3]) - bbox_max = np.array(bounds[3:6]) - bbox_size = bbox_max - bbox_min - - xyz_samples, grid_size, length = generate_dense_grid_points( - bbox_min=bbox_min, - bbox_max=bbox_max, - octree_depth=octree_depth, - indexing="ij" - ) - xyz_samples = torch.FloatTensor(xyz_samples) - batch_size = latents.shape[0] - - batch_logits = [] - for start in range(0, xyz_samples.shape[0], num_chunks): - queries = xyz_samples[start: start + num_chunks, :].to(latents) - batch_queries = repeat(queries, "p c -> b p c", b=batch_size) - - logits = self.query(batch_queries, latents) - batch_logits.append(logits.cpu()) - - grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float().numpy() - - mesh_v_f = [] - has_surface = np.zeros((batch_size,), dtype=np.bool_) - for i in range(batch_size): - try: - vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") - # vertices, faces = mcubes.marching_cubes(grid_logits[i], 0) - vertices = vertices / grid_size * bbox_size + bbox_min - faces = faces[:, [2, 1, 0]] - mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) - has_surface[i] = True - except: - mesh_v_f.append((None, None)) - has_surface[i] = False - - return mesh_v_f, has_surface - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): - self.feat_dim = feat_dim - self.parameters = parameters - - if isinstance(parameters, list): - self.mean = parameters[0] - self.logvar = parameters[1] - else: - self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) - - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean) - - def sample(self): - x = self.mean + self.std * torch.randn_like(self.mean) - return x - - def kl(self, other=None, dims=(1, 2)): - if self.deterministic: - return torch.Tensor([0.]) - else: - if other is None: - return 0.5 * torch.mean(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=dims) - else: - return 0.5 * torch.mean( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=dims) - - def nll(self, sample, dims=(1, 2)): - if self.deterministic: - return torch.Tensor([0.]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) - - def mode(self): - return self.mean diff --git a/craftsman/models/conditional_encoders/__init__.py b/craftsman/models/conditional_encoders/__init__.py index be7f5c66f7fa65b598a3e47e627d2318d769b260..576ac14341e4f34681370903125df7ac1c75eb7e 100755 --- a/craftsman/models/conditional_encoders/__init__.py +++ b/craftsman/models/conditional_encoders/__init__.py @@ -1,3 +1,3 @@ from . import ( - clip_encoder, + cond_encoder ) diff --git a/craftsman/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc b/craftsman/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cdb9960006c8152e0087a74620d445905ca85c4 Binary files /dev/null and b/craftsman/models/conditional_encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/models/conditional_encoders/__pycache__/__init__.cpython-38.pyc b/craftsman/models/conditional_encoders/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 3ccb873be8e22f71330428529d306ad6db924c25..0000000000000000000000000000000000000000 Binary files a/craftsman/models/conditional_encoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/conditional_encoders/__pycache__/base.cpython-310.pyc b/craftsman/models/conditional_encoders/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bd6963909f1f914109d6babaea136f7b82e87f9 Binary files /dev/null and b/craftsman/models/conditional_encoders/__pycache__/base.cpython-310.pyc differ diff --git a/craftsman/models/conditional_encoders/__pycache__/base.cpython-38.pyc b/craftsman/models/conditional_encoders/__pycache__/base.cpython-38.pyc deleted file mode 100644 index beffb7a70ceb8236c6a611bf57112bf50fadb2dc..0000000000000000000000000000000000000000 Binary files a/craftsman/models/conditional_encoders/__pycache__/base.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/conditional_encoders/__pycache__/clip_encoder.cpython-38.pyc b/craftsman/models/conditional_encoders/__pycache__/clip_encoder.cpython-38.pyc deleted file mode 100644 index 883250d27f42079814e36cf2eeaa5405557f843c..0000000000000000000000000000000000000000 Binary files a/craftsman/models/conditional_encoders/__pycache__/clip_encoder.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/conditional_encoders/__pycache__/cond_encoder.cpython-310.pyc b/craftsman/models/conditional_encoders/__pycache__/cond_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aa092ceb6e5e4d1b03c3ad1c906ba5a44f89763 Binary files /dev/null and b/craftsman/models/conditional_encoders/__pycache__/cond_encoder.cpython-310.pyc differ diff --git a/craftsman/models/conditional_encoders/base.py b/craftsman/models/conditional_encoders/base.py index 7bda4c04e2a59d2a60a91356e11b76a4ae6c66bc..86933ae169915d031c4f4cd55ae304562a88de31 100755 --- a/craftsman/models/conditional_encoders/base.py +++ b/craftsman/models/conditional_encoders/base.py @@ -69,9 +69,6 @@ class BaseEmbedder(BaseModule): def encode_image(self, images: Iterable[Optional[ImageType]], camera_embeds: Optional[torch.Tensor] = None, **kwargs) -> torch.FloatTensor: pass - def encode_text(self, texts: List[str], **kwargs) -> torch.FloatTensor: - pass - def encode_camera(self, c2ws: torch.Tensor): if self.cfg.camera_embeds_type == "sincos": assert c2ws.shape[-1] == 4 and c2ws.shape[-2] == 4, f"Invalid c2ws shape: {c2ws.shape}" @@ -80,46 +77,32 @@ class BaseEmbedder(BaseModule): else: raise NotImplementedError(f"Unknown camera_embeds_type: {self.cfg.camera_embeds_type}") - def post_process_embeds(self, text_embeds, visual_embeds): - bs = text_embeds.shape[0] if text_embeds is not None else visual_embeds.shape[0] + def post_process_embeds(self, visual_embeds): + bs =visual_embeds.shape[0] if self.cfg.normalize_embeds: - # post-process the text/visual embeds - if text_embeds is not None: - text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) + # post-process the visual embeds if visual_embeds is not None: visual_embeds = visual_embeds / visual_embeds.norm(dim=-1, keepdim=True) - assert text_embeds is not None or visual_embeds is not None - - # return text_embeds, visual_embeds - if text_embeds is not None and visual_embeds is not None: - return torch.cat([text_embeds, visual_embeds], dim=1) - elif text_embeds is not None: - return text_embeds - else: - return visual_embeds + assert visual_embeds is not None + # return visual_embeds + return visual_embeds def forward(self, batch): - bs = batch["surface"].shape[0] + if batch["image"].dim() == 5: + bs = batch["image"].shape[0] * batch["image"].shape[1] + else: + bs = batch["image"].shape[0] - text_embeds, visual_embeds = None, None + visual_embeds = None if random.random() < self.cfg.empty_embeds_ratio: - if "text_input_ids" in batch or "text_embeds" in batch: - if self.empty_text_embeds is None: - if not self.cfg.zero_uncond_embeds: - self.empty_text_embeds = self.encode_text([""]).detach() # [1, 77, 768] - text_embeds = self.empty_text_embeds.repeat(bs, 1, 1) if "image" in batch or "image_embeds" in batch: visual_embeds = self.empty_image_embeds.repeat(bs, 1, 1) elif "mvimages" in batch or "mvimage_embeds" in batch: visual_embeds = self.empty_image_embeds.unsqueeze(1).repeat(bs, 1, 1, 1) else: - # for text inputs - if "text_input_ids" in batch: - text_embeds = self.encode_text(batch["text_input_ids"]) - # for visual inputs if "image" in batch: if self.cfg.encode_camera: @@ -136,4 +119,4 @@ class BaseEmbedder(BaseModule): visual_embeds = self.encode_image( batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:])).view(bs, n_views, *self.empty_image_embeds.shape[-2:]) - return self.post_process_embeds(text_embeds, visual_embeds) + return self.post_process_embeds(visual_embeds) diff --git a/craftsman/models/conditional_encoders/clip/__pycache__/modeling_clip.cpython-310.pyc b/craftsman/models/conditional_encoders/clip/__pycache__/modeling_clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b2783cc2194ea5717c0037b3b33547d75372724 Binary files /dev/null and b/craftsman/models/conditional_encoders/clip/__pycache__/modeling_clip.cpython-310.pyc differ diff --git a/craftsman/models/conditional_encoders/clip/__pycache__/modeling_clip.cpython-38.pyc b/craftsman/models/conditional_encoders/clip/__pycache__/modeling_clip.cpython-38.pyc deleted file mode 100644 index 9803c8612a0de3c10577ab23fe901bc4f0f5f2a0..0000000000000000000000000000000000000000 Binary files a/craftsman/models/conditional_encoders/clip/__pycache__/modeling_clip.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/conditional_encoders/clip/__pycache__/modeling_conditional_clip.cpython-310.pyc b/craftsman/models/conditional_encoders/clip/__pycache__/modeling_conditional_clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f0517d090506cdf564f4e2c281cc1f4e7db2667 Binary files /dev/null and b/craftsman/models/conditional_encoders/clip/__pycache__/modeling_conditional_clip.cpython-310.pyc differ diff --git a/craftsman/models/conditional_encoders/clip/__pycache__/modeling_conditional_clip.cpython-38.pyc b/craftsman/models/conditional_encoders/clip/__pycache__/modeling_conditional_clip.cpython-38.pyc deleted file mode 100644 index 4d0af12db96dfa86a8a930ce81e77fc70baba29c..0000000000000000000000000000000000000000 Binary files a/craftsman/models/conditional_encoders/clip/__pycache__/modeling_conditional_clip.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/conditional_encoders/clip/modeling_clip.py b/craftsman/models/conditional_encoders/clip/modeling_clip.py index 5027c30dbde6d809867bf9ca3b4ed50f4be44c03..cefb5c90e38f100934e3756c850c07dfcadbbba3 100755 --- a/craftsman/models/conditional_encoders/clip/modeling_clip.py +++ b/craftsman/models/conditional_encoders/clip/modeling_clip.py @@ -149,7 +149,7 @@ class CLIPOutput(ModelOutput): text_model_output: BaseModelOutputWithPooling = None vision_model_output: BaseModelOutputWithPooling = None - def to_tuple(self): + def to_tuple(self) -> Tuple[Any]: return tuple( self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() for k in self.keys() diff --git a/craftsman/models/conditional_encoders/clip_encoder.py b/craftsman/models/conditional_encoders/clip_encoder.py deleted file mode 100755 index 326f3b67b4a33ecc969e90cba5e1e893c201aec7..0000000000000000000000000000000000000000 --- a/craftsman/models/conditional_encoders/clip_encoder.py +++ /dev/null @@ -1,172 +0,0 @@ -import random -import torch -from torch import nn -import numpy as np -from PIL import Image -from einops import rearrange -from dataclasses import dataclass -from torchvision.transforms import Normalize -from torchvision.transforms import InterpolationMode -from torchvision.transforms.transforms import _interpolation_modes_from_int -from torchvision import transforms - -from transformers import CLIPTokenizer, CLIPImageProcessor -from transformers.utils import ModelOutput -from typing import Iterable, Optional, Union, List - -import craftsman -from craftsman.utils.typing import * -from .clip.modeling_clip import CLIPModel -from .clip.modeling_conditional_clip import ConditionalCLIPModel -from .base import BaseEmbedder, ImageType - -@dataclass -class CLIPEmbedOutput(ModelOutput): - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - embeds: torch.FloatTensor = None - -@craftsman.register("clip-embedder") -class CLIPEmbedder(BaseEmbedder): - - @dataclass - class Config(BaseEmbedder.Config): - freeze_modulation: bool = False - config_path: str = '' - - cfg: Config - - def configure(self) -> None: - super().configure() - - # Load the CLIP model and processor - if not self.cfg.encode_camera: - self.model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_model_name_or_path) - else: - if self.cfg.pretrained_model_name_or_path == '': - assert self.cfg.config_path is not None, "The config path should be provided" - conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path) - conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim - self.model: CLIPModel = ConditionalCLIPModel(conditional_clip_config) - else: - conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained( - self.cfg.pretrained_model_name_or_path, - ) - conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim - self.model: CLIPModel = ConditionalCLIPModel.from_pretrained( - self.cfg.pretrained_model_name_or_path, - vision_config=conditional_clip_config.vision_config - ) - - self.tokenizer = None - self.image_preprocess = CLIPImageProcessor() - self.transform = transforms.Compose( - [ - transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True), - transforms.CenterCrop(224), # crop a (224, 224) square - transforms.Normalize( - mean=[0.48145466, 0.4578275, 0.40821073], - std=[0.26862954, 0.26130258, 0.27577711], - ), - ] - ) - - self.logit_scale = self.model.logit_scale.exp() - - if self.cfg.zero_uncond_embeds: - self.empty_text_embeds = torch.zeros((1, 77, 768)).detach() - self.empty_image_embeds = torch.zeros((self.cfg.n_views, 257, 1024)).detach() - else: - try: - self.empty_text_embeds = self.encode_text([""]).detach() # [1, 77, 768] - except: - self.empty_text_embeds = None - if self.cfg.encode_camera: - self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3), self.cameras[:self.cfg.n_views]).detach() - else: - self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3)).detach() - - # Freeze the model parameters - self.model.eval() - for k, p in self.model.named_parameters(): - ks = k.split('.') - if 'mod_norm1' in ks or 'mod_norm2' in ks and not self.cfg.freeze_modulation: - p.requires_grad_(True) - else: - p.requires_grad_(False) - - def encode_image(self, images: Iterable[Optional[ImageType]], cameras: Optional[torch.Tensor] = None, force_none_camera_embeds: bool = False, return_dict: bool = False, **kwargs) -> torch.FloatTensor: - camera_embeds = None - if isinstance(images, (np.ndarray, torch.Tensor)): # for training process - assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]" - do_rescale = False - if self.cfg.encode_camera: - assert cameras is not None, "The cameras should be provided" - camera_embeds = self.encode_camera(cameras) - pixel_values = self.transform(images.permute(0, 3, 1, 2)) - else: # for inference process - do_rescale = True - if self.cfg.encode_camera: - if cameras is None: - bs = len(images) // self.cfg.n_views - cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.model.device) - camera_embeds = self.encode_camera(cameras) - pixel_values = self.image_preprocess.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values - - if force_none_camera_embeds: - camera_embeds = None - - packed = False - if pixel_values.ndim == 4: - packed = True - pixel_values = pixel_values.unsqueeze(1) - if camera_embeds is not None: - camera_embeds = camera_embeds.unsqueeze(1) - - if self.cfg.encode_camera and camera_embeds is not None: - vision_outputs = self.model.vision_model( - pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"), - condition=rearrange(camera_embeds, "B N C -> (B N) C") - ) - else: - vision_outputs = self.model.vision_model( - pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"), - ) - - if return_dict: - pooler_output = vision_outputs[1] # pooled_output - image_features = self.model.visual_projection(pooler_output) - - return CLIPEmbedOutput( - last_hidden_state=vision_outputs.last_hidden_state, - pooler_output=pooler_output, - embeds=image_features - ) - else: - return vision_outputs.last_hidden_state - - @torch.no_grad() - def encode_text(self, text_inputs: torch.Tensor, return_dict: bool = False) -> torch.FloatTensor: - if self.tokenizer is None: - self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.pretrained_model_name_or_path) - - if isinstance(text_inputs, list): - text_inputs = self.tokenizer( - text_inputs, - max_length=self.tokenizer.model_max_length, - padding="max_length", - return_tensors="pt" - ).input_ids - text_outputs = self.model.text_model(input_ids=text_inputs.to(self.model.device)) - - pooler_output = text_outputs[1] # pooled_output - text_features = self.model.text_projection(pooler_output) - - if return_dict: - return CLIPEmbedOutput( - last_hidden_state=text_outputs.last_hidden_state, - pooler_output=pooler_output, - embeds=text_features - ) - else: - return text_outputs.last_hidden_state \ No newline at end of file diff --git a/craftsman/models/conditional_encoders/cond_encoder.py b/craftsman/models/conditional_encoders/cond_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1faa5eb44fc421e84af7e8f5c54a4c4389396464 --- /dev/null +++ b/craftsman/models/conditional_encoders/cond_encoder.py @@ -0,0 +1,305 @@ +import random +import torch +from torch import nn +import numpy as np +import re +from einops import rearrange +from dataclasses import dataclass +from torchvision import transforms + +from transformers import CLIPTokenizer, CLIPImageProcessor +from transformers import AutoImageProcessor +from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer +from transformers.utils import ModelOutput +from typing import Iterable, Optional, Union, List + +import craftsman +from craftsman.utils.typing import * +from .clip.modeling_clip import CLIPModel +from .clip.modeling_conditional_clip import ConditionalCLIPModel +from .base import BaseEmbedder, ImageType +from .dino_v2.modeling_dinov2 import Dinov2Model +from .dino_v2.modeling_conditional_dinov2 import ConditionalDinov2Model + +@dataclass +class CLIPEmbedOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + embeds: torch.FloatTensor = None + +class DINOEmbedOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + +@craftsman.register("cond-embedder") +class CondEmbedder(BaseEmbedder): + + @dataclass + class Config(BaseEmbedder.Config): + pretrained_model_name_or_path: Optional[str] = None # the pretrained model name or path for condition model + pretrained_clip_name_or_path: Optional[str] = None # the pretrained model name or path for clip + pretrained_dino_name_or_path: Optional[str] = None # the pretrained model name or path for dino + pretrained_linear_proj: Optional[str] = None + freeze_modulation_clip: bool = False + freeze_modulation_dino: bool = False + config_path: str = '' + enable_gradient_checkpointing: bool = False + embeds_fusion_mode: int = 1 # 0: sum | 1: concat + linear_proj_init: str = "constant" + text_model_type: str = "clip" + text_max_length: int = 77 + image_size_clip: int = 224 + image_size_dino: int = 224 + + cfg: Config + + def configure(self) -> None: + super().configure() + + # Load the CLIP model and processor + if not self.cfg.encode_camera: + if self.cfg.pretrained_clip_name_or_path is not None: + self.clip_model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_clip_name_or_path) + else: + self.clip_model: CLIPModel = CLIPModel(config=ConditionalCLIPModel.config_class.from_pretrained( + "openai/clip-vit-large-patch14", + )) + if self.cfg.pretrained_dino_name_or_path is not None: + self.dino_model: Dinov2Model = Dinov2Model.from_pretrained(self.cfg.pretrained_dino_name_or_path) + else: + self.dino_model: Dinov2Model = Dinov2Model(config=ConditionalDinov2Model.config_class.from_pretrained( + "facebook/dinov2-base", + )) + else: + if self.cfg.pretrained_clip_name_or_path == '': + assert self.cfg.config_path is not None, "The config path should be provided" + conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path) + conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim + self.clip_model: CLIPModel = ConditionalCLIPModel(conditional_clip_config) + else: + + # clip + conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained( + self.cfg.pretrained_clip_name_or_path, + ) + conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim + self.clip_model: CLIPModel = ConditionalCLIPModel.from_pretrained( + self.cfg.pretrained_clip_name_or_path, + vision_config=conditional_clip_config.vision_config + ) + + # dino + conditional_vit_config = ConditionalDinov2Model.config_class.from_pretrained( + self.cfg.pretrained_dino_name_or_path, + ) + conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim + self.dino_model: ConditionalDinov2Model = ConditionalDinov2Model.from_pretrained( + self.cfg.pretrained_dino_name_or_path, + config=conditional_vit_config + ) + + self.image_preprocess_clip = CLIPImageProcessor() + self.image_preprocess_dino = AutoImageProcessor.from_pretrained( + self.cfg.pretrained_dino_name_or_path if self.cfg.pretrained_dino_name_or_path is not None else "facebook/dinov2-base", + ) + self.transform_clip= transforms.Compose( + [ + transforms.Resize(self.cfg.image_size_clip, transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.CenterCrop(self.cfg.image_size_clip), # crop a (224, 224) square + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + + self.transform_dino = transforms.Compose( + [ + transforms.Resize(self.cfg.image_size_dino, transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.CenterCrop(self.cfg.image_size_dino), # crop a (224, 224) square + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] + ) + + if self.cfg.enable_gradient_checkpointing: + self.dino_model.encoder.gradient_checkpointing = True + + if self.cfg.zero_uncond_embeds: + self.empty_image_embeds_clip = torch.zeros((self.cfg.n_views, 257, 1024)).detach() + self.empty_image_embeds_dino = torch.zeros((self.cfg.n_views, 257, 1024)).detach() + self.empty_image_embeds = torch.cat([self.empty_image_embeds_clip, self.empty_image_embeds_dino], dim=1) + else: + if self.cfg.encode_camera: + self.empty_image_embeds_clip = self.encode_image_clip(torch.zeros(self.cfg.n_views, self.cfg.image_size_clip, self.cfg.image_size_clip, 3), self.cameras[:self.cfg.n_views]).detach() + self.empty_image_embeds_dino = self.encode_image_dino(torch.zeros(self.cfg.n_views, self.cfg.image_size_clip, self.cfg.image_size_clip, 3), self.cameras[:self.cfg.n_views]).detach() + self.empty_image_embeds = torch.cat([self.empty_image_embeds_clip, self.empty_image_embeds_dino], dim=1) + else: + self.empty_image_embeds_clip = self.encode_image_clip(torch.zeros(self.cfg.n_views, self.cfg.image_size_dino, self.cfg.image_size_dino, 3)).detach() + self.empty_image_embeds_dino = self.encode_image_dino(torch.zeros(self.cfg.n_views, self.cfg.image_size_dino, self.cfg.image_size_dino, 3)).detach() + self.empty_image_embeds = torch.cat([self.empty_image_embeds_clip, self.empty_image_embeds_dino], dim=1) + + # Freeze the clip model parameters + self.clip_model.eval() + for k, p in self.clip_model.named_parameters(): + ks = k.split('.') + if 'mod_norm1' in ks or 'mod_norm2' in ks and not self.cfg.freeze_modulation_clip: + p.requires_grad_(not self.cfg.freeze_modulation_clip) + else: + p.requires_grad_(False) + + # freeze the dino model parameters + self.dino_model.eval() + for k, p in self.dino_model.named_parameters(): + ks = k.split('.') + if 'mod_norm1' in ks or 'mod_norm2' in ks and not self.cfg.freeze_modulation_dino: + p.requires_grad_(not self.cfg.freeze_modulation_dino) + else: + p.requires_grad_(False) + + self.linear_proj = nn.Linear(768, 1024, bias=False) + if self.cfg.linear_proj_init == "constant": + nn.init.constant_(self.linear_proj.weight, 0) + elif self.cfg.linear_proj_init == "xavier": + nn.init.xavier_uniform_(self.linear_proj.weight) + else: + raise ValueError + + if self.cfg.pretrained_model_name_or_path is not None: + print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}") + ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")['state_dict'] + pretrained_model_ckpt = {} + for k, v in ckpt.items(): + if k.startswith('condition.'): + pretrained_model_ckpt[k.replace('condition.', '')] = v + self.load_state_dict(pretrained_model_ckpt, strict=False) + + def encode_image_clip(self, images: Iterable[Optional[ImageType]], cameras: Optional[torch.Tensor] = None, force_none_camera_embeds: bool = False, return_dict: bool = False, **kwargs) -> torch.FloatTensor: + camera_embeds = None + if isinstance(images, (np.ndarray, torch.Tensor)): # for training process + assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]" + do_rescale = False + if self.cfg.encode_camera: + assert cameras is not None, "The cameras should be provided" + camera_embeds = self.encode_camera(cameras) + pixel_values = self.transform_clip(images.permute(0, 3, 1, 2)) + else: # for inference process + do_rescale = True + if self.cfg.encode_camera: + if cameras is None: + bs = len(images) // self.cfg.n_views + cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.clip_model.device) + camera_embeds = self.encode_camera(cameras) + pixel_values = self.image_preprocess_clip.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values + + if force_none_camera_embeds: + camera_embeds = None + + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + if camera_embeds is not None: + camera_embeds = camera_embeds.unsqueeze(1) + + if self.cfg.encode_camera and camera_embeds is not None: + vision_outputs = self.clip_model.vision_model( + pixel_values=rearrange(pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"), + condition=rearrange(camera_embeds, "B N C -> (B N) C") + ) + + else: + vision_outputs = self.clip_model.vision_model( + pixel_values=rearrange(pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"), + ) + + if return_dict: + # clip + pooler_output = vision_outputs[1] # pooled_output + image_features = self.clip_model.visual_projection(pooler_output) + + clip_embeds = vision_outputs.last_hidden_state + + clip_embeds_dict = CLIPEmbedOutput( + last_hidden_state=clip_embeds, + pooler_output=pooler_output, + embeds=image_features + ) + + return clip_embeds_dict + else: + return vision_outputs.last_hidden_state + + def encode_image_dino(self, images: Iterable[Optional[ImageType]], cameras: Optional[torch.Tensor] = None, force_none_camera_embeds: bool = False, return_dict: bool = False, **kwargs) -> torch.FloatTensor: + camera_embeds = None + if isinstance(images, (np.ndarray, torch.Tensor)): # for training process + assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]" + do_rescale = False + if self.cfg.encode_camera: + assert cameras is not None, "The cameras should be provided" + camera_embeds = self.encode_camera(cameras) + pixel_values = self.transform_dino(images.permute(0, 3, 1, 2)) + else: # for inference process + do_rescale = True + if self.cfg.encode_camera: + if cameras is None: + bs = len(images) // self.cfg.n_views + cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.dino_model.device) + camera_embeds = self.encode_camera(cameras) + pixel_values = self.image_preprocess_dino.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values + + if force_none_camera_embeds: + camera_embeds = None + + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + if camera_embeds is not None: + camera_embeds = camera_embeds.unsqueeze(1) + + if self.cfg.encode_camera and camera_embeds is not None: + vision_outputs = self.dino_model( + rearrange(pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"), + condition=rearrange(camera_embeds, "B N C -> (B N) C"), + ) + else: + + vision_outputs = self.dino_model( + rearrange(pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"), + ) + + if return_dict: + # dino + dino_embeds_dict = DINOEmbedOutput( + last_hidden_state=vision_outputs.last_hidden_state, + pooler_output=vision_outputs.pooler_output, + ) + return dino_embeds_dict + else: + return vision_outputs.last_hidden_state + + def post_process_embeds(self, text_embeds, visual_embeds): + clip_embeds, dino_embeds = visual_embeds.chunk(2, dim=2) + if self.cfg.normalize_embeds: + # post-process the text/visual embeds + if text_embeds is not None: + text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) + if clip_embeds is not None: + clip_embeds = clip_embeds / clip_embeds.norm(dim=-1, keepdim=True) + if dino_embeds is not None: + dino_embeds = dino_embeds / dino_embeds.norm(dim=-1, keepdim=True) + + assert text_embeds is not None or dino_embeds is not None or clip_embeds is not None + + if text_embeds is not None and visual_embeds is not None: + return torch.cat([text_embeds, visual_embeds], dim=1) + elif text_embeds is not None: + return text_embeds + else: + return visual_embeds + + def encode_image(self, images: Iterable[Optional[ImageType]], cameras: Optional[torch.Tensor] = None, force_none_camera_embeds: bool = False, return_dict: bool = False, **kwargs) -> torch.FloatTensor: + clip_embeds = self.encode_image_clip(images, cameras) + dino_embeds = self.encode_image_dino(images, cameras) + dino_embeds = self.linear_proj(dino_embeds) + visual_embeds = torch.cat([clip_embeds, dino_embeds], dim=1) + return visual_embeds \ No newline at end of file diff --git a/craftsman/models/conditional_encoders/dino_v2/__pycache__/modeling_conditional_dinov2.cpython-310.pyc b/craftsman/models/conditional_encoders/dino_v2/__pycache__/modeling_conditional_dinov2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8cc31eaf7786f8da942a578c437e98187aead95 Binary files /dev/null and b/craftsman/models/conditional_encoders/dino_v2/__pycache__/modeling_conditional_dinov2.cpython-310.pyc differ diff --git a/craftsman/models/conditional_encoders/dino_v2/__pycache__/modeling_dinov2.cpython-310.pyc b/craftsman/models/conditional_encoders/dino_v2/__pycache__/modeling_dinov2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddafb045cd5591010924c1b676245338418afee4 Binary files /dev/null and b/craftsman/models/conditional_encoders/dino_v2/__pycache__/modeling_dinov2.cpython-310.pyc differ diff --git a/craftsman/models/conditional_encoders/dino_v2/modeling_conditional_dinov2.py b/craftsman/models/conditional_encoders/dino_v2/modeling_conditional_dinov2.py new file mode 100755 index 0000000000000000000000000000000000000000..d7e085f1e76dac690ed29c7bc7a481650e5ab708 --- /dev/null +++ b/craftsman/models/conditional_encoders/dino_v2/modeling_conditional_dinov2.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Reference: +# * transformers/models/dinov2/modeling_dinov2.py +# * https://github.com/facebookresearch/DiT/blob/main/models.py#L101 +# * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2 +""" PyTorch DINOv2 model.""" + +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn + +from .modeling_dinov2 import ( + Dinov2Config, + Dinov2Layer, + Dinov2Model, + Dinov2Embeddings, + BaseModelOutput, + BaseModelOutputWithPooling, +) + + +class ModLN(nn.Module): + def __init__(self, inner_dim: int, mod_dim: int = 1024): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(mod_dim, inner_dim * 2), + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x:torch.Tensor, condition:torch.Tensor): + ''' + x: [N, M, C_in], M: num of tokens + condition: [N, C_mod] + ''' + shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1) + return x * (1 + scale) + shift + + +class ConditionalDinov2Config(Dinov2Config): + def __init__(self, modulation_dim: int = 1024, *args, **kwargs): + super().__init__(*args, **kwargs) + self.modulation_dim = modulation_dim + + +class ConditionalDinov2Layer(Dinov2Layer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: ConditionalDinov2Config) -> None: + super().__init__(config) + self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim) + self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + condition: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.mod_norm1(self.norm1(hidden_states), condition), # in Dinov2, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.mod_norm2(self.norm2(hidden_states), condition) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2 +class ConditionalDinov2Encoder(nn.Module): + def __init__(self, config: ConditionalDinov2Config) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ConditionalDinov2Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + condition: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + condition, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + layer_head_mask, + condition, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ConditionalDinov2Model(Dinov2Model): + config_class = ConditionalDinov2Config + def __init__(self, config: ConditionalDinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = ConditionalDinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + condition: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + condition=condition, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + diff --git a/craftsman/models/conditional_encoders/dino_v2/modeling_dinov2.py b/craftsman/models/conditional_encoders/dino_v2/modeling_dinov2.py new file mode 100755 index 0000000000000000000000000000000000000000..d2f1c23ff6a1753f7f40a2a5cf06782530ce9da8 --- /dev/null +++ b/craftsman/models/conditional_encoders/dino_v2/modeling_dinov2.py @@ -0,0 +1,859 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DINOv2 model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.backbone_utils import BackboneMixin +from transformers.models.dinov2.configuration_dinov2 import Dinov2Config + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Dinov2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2-base" +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/dinov2-base", + # See all DINOv2 models at https://huggingface.co./models?filter=dinov2 +] + + +class Dinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.patch_embeddings = Dinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))), + mode="bicubic", + align_corners=False, + ).to(dtype=target_dtype) + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 +class Dinov2SelfAttention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 +class Dinov2SelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 +class Dinov2Attention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.attention = Dinov2SelfAttention(config) + self.output = Dinov2SelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class Dinov2LayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class Dinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2MLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2SwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class Dinov2Layer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Dinov2Attention(config) + self.layer_scale1 = Dinov2LayerScale(config) + self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2SwiGLUFFN(config) + else: + self.mlp = Dinov2MLP(config) + self.layer_scale2 = Dinov2LayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Dinov2, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2 +class Dinov2Encoder(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2Embeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +DINOV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class Dinov2Model(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2ForImageClassification(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2 = Dinov2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinov2( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 backbone, to be used with frameworks like DETR and MaskFormer. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, 1:] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) diff --git a/craftsman/models/denoisers/__init__.py b/craftsman/models/denoisers/__init__.py index b878f033b5bfc7ed68b926c4c692ed9fca17bced..0b63ea141c0347a3ea586658f78871db08e12aa1 100755 --- a/craftsman/models/denoisers/__init__.py +++ b/craftsman/models/denoisers/__init__.py @@ -1,3 +1,3 @@ from . import ( - simple_denoiser, + pixart_denoiser ) diff --git a/craftsman/models/denoisers/__pycache__/__init__.cpython-310.pyc b/craftsman/models/denoisers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7403ff05f007ce3dcd68a6b68853f6d9e2b9c765 Binary files /dev/null and b/craftsman/models/denoisers/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/models/denoisers/__pycache__/__init__.cpython-38.pyc b/craftsman/models/denoisers/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 5cb6488f20b57944d0636465b6d086a0ae737f4e..0000000000000000000000000000000000000000 Binary files a/craftsman/models/denoisers/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/denoisers/__pycache__/pixart_denoiser.cpython-310.pyc b/craftsman/models/denoisers/__pycache__/pixart_denoiser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64aab3f6a63158b80efed5c4714f60ef5cbd37c6 Binary files /dev/null and b/craftsman/models/denoisers/__pycache__/pixart_denoiser.cpython-310.pyc differ diff --git a/craftsman/models/denoisers/__pycache__/simple_denoiser.cpython-38.pyc b/craftsman/models/denoisers/__pycache__/simple_denoiser.cpython-38.pyc deleted file mode 100644 index 40fafaaf160a2990ced1878c5c84ecd94e462cf2..0000000000000000000000000000000000000000 Binary files a/craftsman/models/denoisers/__pycache__/simple_denoiser.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/denoisers/__pycache__/utils.cpython-310.pyc b/craftsman/models/denoisers/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a5b02a2bf10278ba7fc10ac184473d9712092ae Binary files /dev/null and b/craftsman/models/denoisers/__pycache__/utils.cpython-310.pyc differ diff --git a/craftsman/models/denoisers/pixart_denoiser.py b/craftsman/models/denoisers/pixart_denoiser.py new file mode 100755 index 0000000000000000000000000000000000000000..b34f1caa2d4bf6d8385106f619dcad6314cb7281 --- /dev/null +++ b/craftsman/models/denoisers/pixart_denoiser.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import math +import importlib +import craftsman +import re + +from typing import Optional +from craftsman.utils.base import BaseModule +from craftsman.models.denoisers.utils import * + +@craftsman.register("pixart-denoiser") +class PixArtDinoDenoiser(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: Optional[str] = None + input_channels: int = 32 + output_channels: int = 32 + n_ctx: int = 512 + width: int = 768 + layers: int = 28 + heads: int = 16 + context_dim: int = 1024 + n_views: int = 1 + context_ln: bool = True + skip_ln: bool = False + init_scale: float = 0.25 + use_checkpoint: bool = False + drop_path: float = 0. + variance_type: str = "" + img_pos_embed: bool = False + clip_weight: float = 1.0 + dino_weight: float = 1.0 + dit_block: str = "" + + cfg: Config + + def configure(self) -> None: + super().configure() + + # timestep embedding + self.time_embed = TimestepEmbedder(self.cfg.width) + + # x embedding + self.x_embed = nn.Linear(self.cfg.input_channels, self.cfg.width, bias=True) + + # context embedding + if self.cfg.context_ln: + self.clip_embed = nn.Sequential( + nn.LayerNorm(self.cfg.context_dim), + nn.Linear(self.cfg.context_dim, self.cfg.width), + ) + + self.dino_embed = nn.Sequential( + nn.LayerNorm(self.cfg.context_dim), + nn.Linear(self.cfg.context_dim, self.cfg.width), + ) + else: + self.clip_embed = nn.Linear(self.cfg.context_dim, self.cfg.width) + self.dino_embed = nn.Linear(self.cfg.context_dim, self.cfg.width) + + init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width) + drop_path = [x.item() for x in torch.linspace(0, self.cfg.drop_path, self.cfg.layers)] + ditblock = getattr(importlib.import_module("craftsman.models.denoisers.utils"), self.cfg.dit_block) + self.blocks = nn.ModuleList([ + ditblock( + width=self.cfg.width, + heads=self.cfg.heads, + init_scale=init_scale, + qkv_bias=self.cfg.drop_path, + use_flash=True, + drop_path=drop_path[i] + ) + for i in range(self.cfg.layers) + ]) + + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(self.cfg.width, 6 * self.cfg.width, bias=True) + ) + + # final layer + if self.cfg.variance_type.upper() in ["LEARNED", "LEARNED_RANGE"]: + self.output_channels = self.cfg.output_channels * 2 + else: + self.output_channels = self.cfg.output_channels + self.final_layer = T2IFinalLayer(self.cfg.width, self.output_channels) + + self.identity_initialize() + + if self.cfg.pretrained_model_name_or_path: + print(f"Loading pretrained model from {self.cfg.pretrained_model_name_or_path}") + ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")['state_dict'] + self.denoiser_ckpt = {} + for k, v in ckpt.items(): + if k.startswith('denoiser_model.'): + self.denoiser_ckpt[k.replace('denoiser_model.', '')] = v + self.load_state_dict(self.denoiser_ckpt, strict=False) + + def forward_with_dpmsolver(self, model_input, timestep, context): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(model_input, timestep, context) + if self.cfg.variance_type.upper() in ["LEARNED", "LEARNED_RANGE"]: + return model_out.chunk(2, dim=-1)[0] + else: + return model_out + + def identity_initialize(self): + for block in self.blocks: + nn.init.constant_(block.attn.c_proj.weight, 0) + nn.init.constant_(block.attn.c_proj.bias, 0) + nn.init.constant_(block.cross_attn.c_proj.weight, 0) + nn.init.constant_(block.cross_attn.c_proj.bias, 0) + nn.init.constant_(block.mlp.c_proj.weight, 0) + nn.init.constant_(block.mlp.c_proj.bias, 0) + + def forward(self, + model_input: torch.FloatTensor, + timestep: torch.LongTensor, + context: torch.FloatTensor): + + r""" + Args: + model_input (torch.FloatTensor): [bs, n_data, c] + timestep (torch.LongTensor): [bs,] + context (torch.FloatTensor): [bs, context_tokens, c] + + Returns: + sample (torch.FloatTensor): [bs, n_data, c] + + """ + + B, n_data, _ = model_input.shape + + # 1. time + t_emb = self.time_embed(timestep) + + # 2. conditions projector + context = context.view(B, self.cfg.n_views, -1, self.cfg.context_dim) + clip_feat, dino_feat = context.chunk(2, dim=2) + clip_cond = self.clip_embed(clip_feat.contiguous().view(B, -1, self.cfg.context_dim)) + dino_cond = self.dino_embed(dino_feat.contiguous().view(B, -1, self.cfg.context_dim)) + visual_cond = self.cfg.clip_weight * clip_cond + self.cfg.dino_weight * dino_cond + + # 4. denoiser + latent = self.x_embed(model_input) + + t0 = self.t_block(t_emb).unsqueeze(dim=1) + for block in self.blocks: + latent = auto_grad_checkpoint(block, latent, visual_cond, t0) + + latent = self.final_layer(latent, t_emb) + + return latent + diff --git a/craftsman/models/denoisers/simple_denoiser.py b/craftsman/models/denoisers/simple_denoiser.py deleted file mode 100755 index 2b674a3fab1bac10d7aefdc2e26d8b01f8248710..0000000000000000000000000000000000000000 --- a/craftsman/models/denoisers/simple_denoiser.py +++ /dev/null @@ -1,191 +0,0 @@ -from dataclasses import dataclass - -import torch -import torch.nn as nn -from typing import Optional -from diffusers.models.embeddings import Timesteps -import math - -import craftsman -from craftsman.models.transformers.attention import ResidualAttentionBlock -from craftsman.models.transformers.utils import init_linear, MLP -from craftsman.utils.base import BaseModule - - -class UNetDiffusionTransformer(nn.Module): - def __init__( - self, - *, - n_ctx: int, - width: int, - layers: int, - heads: int, - init_scale: float = 0.25, - qkv_bias: bool = False, - skip_ln: bool = False, - use_checkpoint: bool = False - ): - super().__init__() - - self.n_ctx = n_ctx - self.width = width - self.layers = layers - - self.encoder = nn.ModuleList() - for _ in range(layers): - resblock = ResidualAttentionBlock( - n_ctx=n_ctx, - width=width, - heads=heads, - init_scale=init_scale, - qkv_bias=qkv_bias, - use_checkpoint=use_checkpoint - ) - self.encoder.append(resblock) - - self.middle_block = ResidualAttentionBlock( - n_ctx=n_ctx, - width=width, - heads=heads, - init_scale=init_scale, - qkv_bias=qkv_bias, - use_checkpoint=use_checkpoint - ) - - self.decoder = nn.ModuleList() - for _ in range(layers): - resblock = ResidualAttentionBlock( - n_ctx=n_ctx, - width=width, - heads=heads, - init_scale=init_scale, - qkv_bias=qkv_bias, - use_checkpoint=use_checkpoint - ) - linear = nn.Linear(width * 2, width) - init_linear(linear, init_scale) - - layer_norm = nn.LayerNorm(width) if skip_ln else None - - self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) - - def forward(self, x: torch.Tensor): - - enc_outputs = [] - for block in self.encoder: - x = block(x) - enc_outputs.append(x) - - x = self.middle_block(x) - - for i, (resblock, linear, layer_norm) in enumerate(self.decoder): - x = torch.cat([enc_outputs.pop(), x], dim=-1) - x = linear(x) - - if layer_norm is not None: - x = layer_norm(x) - - x = resblock(x) - - return x - - -@craftsman.register("simple-denoiser") -class SimpleDenoiser(BaseModule): - - @dataclass - class Config(BaseModule.Config): - pretrained_model_name_or_path: Optional[str] = None - input_channels: int = 32 - output_channels: int = 32 - n_ctx: int = 512 - width: int = 768 - layers: int = 6 - heads: int = 12 - context_dim: int = 1024 - context_ln: bool = True - skip_ln: bool = False - init_scale: float = 0.25 - flip_sin_to_cos: bool = False - use_checkpoint: bool = False - - cfg: Config - - def configure(self) -> None: - super().configure() - - init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width) - - self.backbone = UNetDiffusionTransformer( - n_ctx=self.cfg.n_ctx, - width=self.cfg.width, - layers=self.cfg.layers, - heads=self.cfg.heads, - skip_ln=self.cfg.skip_ln, - init_scale=init_scale, - use_checkpoint=self.cfg.use_checkpoint - ) - self.ln_post = nn.LayerNorm(self.cfg.width) - self.input_proj = nn.Linear(self.cfg.input_channels, self.cfg.width) - self.output_proj = nn.Linear(self.cfg.width, self.cfg.output_channels) - - # timestep embedding - self.time_embed = Timesteps(self.cfg.width, flip_sin_to_cos=self.cfg.flip_sin_to_cos, downscale_freq_shift=0) - self.time_proj = MLP(width=self.cfg.width, init_scale=init_scale) - - if self.cfg.context_ln: - self.context_embed = nn.Sequential( - nn.LayerNorm(self.cfg.context_dim), - nn.Linear(self.cfg.context_dim, self.cfg.width), - ) - else: - self.context_embed = nn.Linear(self.cfg.context_dim, self.cfg.width) - - if self.cfg.pretrained_model_name_or_path: - pretrained_ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu") - _pretrained_ckpt = {} - for k, v in pretrained_ckpt.items(): - if k.startswith('denoiser_model.'): - _pretrained_ckpt[k.replace('denoiser_model.', '')] = v - pretrained_ckpt = _pretrained_ckpt - if 'state_dict' in pretrained_ckpt: - _pretrained_ckpt = {} - for k, v in pretrained_ckpt['state_dict'].items(): - if k.startswith('denoiser_model.'): - _pretrained_ckpt[k.replace('denoiser_model.', '')] = v - pretrained_ckpt = _pretrained_ckpt - self.load_state_dict(pretrained_ckpt, strict=True) - - def forward(self, - model_input: torch.FloatTensor, - timestep: torch.LongTensor, - context: torch.FloatTensor): - - r""" - Args: - model_input (torch.FloatTensor): [bs, n_data, c] - timestep (torch.LongTensor): [bs,] - context (torch.FloatTensor): [bs, context_tokens, c] - - Returns: - sample (torch.FloatTensor): [bs, n_data, c] - - """ - - _, n_data, _ = model_input.shape - - # 1. time - t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1) - - # 2. conditions projector - context = self.context_embed(context) - - # 3. denoiser - x = self.input_proj(model_input) - x = torch.cat([t_emb, context, x], dim=1) - x = self.backbone(x) - x = self.ln_post(x) - x = x[:, -n_data:] # B, n_data, width - sample = self.output_proj(x) # B, n_data, embed_dim - - return sample \ No newline at end of file diff --git a/craftsman/models/denoisers/utils.py b/craftsman/models/denoisers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9e29cbbf784d8cfa0683eda7b7e72b0cf68356ce --- /dev/null +++ b/craftsman/models/denoisers/utils.py @@ -0,0 +1,330 @@ +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from itertools import repeat +from collections.abc import Iterable +from torch.utils.checkpoint import checkpoint, checkpoint_sequential +from timm.models.layers import DropPath +from craftsman.models.transformers.utils import MLP +from craftsman.models.transformers.attention import MultiheadAttention, MultiheadCrossAttention + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + +class DiTBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, width, heads, init_scale=1.0, qkv_bias=True, use_flash=True, drop_path=0.0): + super().__init__() + self.norm1 = nn.LayerNorm(width, elementwise_affine=True, eps=1e-6) + self.attn = MultiheadAttention( + n_ctx=None, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_flash=use_flash + ) + self.cross_attn = MultiheadCrossAttention( + n_data=None, + width=width, + heads=heads, + data_width=None, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_flash=use_flash, + ) + + self.norm2 = nn.LayerNorm(width, elementwise_affine=True, eps=1e-6) + + self.mlp = MLP(width=width, init_scale=init_scale) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, width) / width ** 0.5) + + def forward(self, x, visual_cond, t, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, visual_cond) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + +class DiTBlock_text(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, width, heads, init_scale=1.0, qkv_bias=True, use_flash=True, drop_path=0.0): + super().__init__() + self.norm1 = nn.LayerNorm(width, elementwise_affine=True, eps=1e-6) + self.attn = MultiheadAttention( + n_ctx=None, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_flash=use_flash + ) + self.cross_attn = MultiheadCrossAttention( + n_data=None, + width=width, + heads=heads, + data_width=None, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_flash=use_flash, + ) + + self.cross_attn_extra = MultiheadCrossAttention( + n_data=None, + width=width, + heads=heads, + data_width=None, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_flash=use_flash, + ) + self.norm2 = nn.LayerNorm(width, elementwise_affine=True, eps=1e-6) + + self.mlp = MLP(width=width, init_scale=init_scale) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, width) / width ** 0.5) + + def forward(self, x, visual_cond, text_cond, t, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, visual_cond) + x = x + self.cross_attn_extra(x, text_cond) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, width, heads, init_scale=1.0, qkv_bias=True, use_flash=True, drop_path=0.0): + super().__init__() + self.norm1 = nn.LayerNorm(width, elementwise_affine=True, eps=1e-6) + self.attn = MultiheadAttention( + n_ctx=None, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_flash=use_flash + ) + self.cross_attn = MultiheadCrossAttention( + n_data=None, + width=width, + heads=heads, + data_width=None, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_flash=use_flash, + ) + self.norm2 = nn.LayerNorm(width, elementwise_affine=True, eps=1e-6) + + self.mlp = MLP(width=width, init_scale=init_scale) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, width) / width ** 0.5) + + def forward(self, x, y, t, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + +# def t2i_modulate(x, shift, scale): +# a = torch.ones_like(scale) +# a[..., 768:] = 0 +# return x * (a + scale) + shift + +def auto_grad_checkpoint(module, *args, **kwargs): + if getattr(module, 'grad_checkpointing', False): + if not isinstance(module, Iterable): + return checkpoint(module, *args, **kwargs) + gc_step = module[0].grad_checkpointing_step + return checkpoint_sequential(module, gc_step, *args, **kwargs) + return module(*args, **kwargs) + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + @property + def dtype(self): + # 返回模型参数的数据类型 + return next(self.parameters()).dtype + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) + self.out_channels = out_channels + + def forward(self, x, t): + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + +def _ntuple(n): + def parse(x): + if isinstance(x, Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = to_2tuple(grid_size) + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed \ No newline at end of file diff --git a/craftsman/models/geometry/__pycache__/__init__.cpython-310.pyc b/craftsman/models/geometry/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e8a39cd825569e0bbe4ba22f7241f6cbedc7069 Binary files /dev/null and b/craftsman/models/geometry/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/models/geometry/__pycache__/__init__.cpython-38.pyc b/craftsman/models/geometry/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index b1cdda25f3d4d8ee4c4d381118069a3d67dbc7d7..0000000000000000000000000000000000000000 Binary files a/craftsman/models/geometry/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/geometry/__pycache__/base.cpython-310.pyc b/craftsman/models/geometry/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54b91b362964b3f4242d48775f76166c08e6f05f Binary files /dev/null and b/craftsman/models/geometry/__pycache__/base.cpython-310.pyc differ diff --git a/craftsman/models/geometry/__pycache__/base.cpython-38.pyc b/craftsman/models/geometry/__pycache__/base.cpython-38.pyc deleted file mode 100644 index 3f8d4aa25c2cdb83fbd6635f374ec39697449c56..0000000000000000000000000000000000000000 Binary files a/craftsman/models/geometry/__pycache__/base.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/geometry/__pycache__/utils.cpython-310.pyc b/craftsman/models/geometry/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..759c2dac939bd6029335df19094e417576e369e9 Binary files /dev/null and b/craftsman/models/geometry/__pycache__/utils.cpython-310.pyc differ diff --git a/craftsman/models/geometry/__pycache__/utils.cpython-38.pyc b/craftsman/models/geometry/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index a12a1e75fc0f7b8d4561597e720e88c65aaa7611..0000000000000000000000000000000000000000 Binary files a/craftsman/models/geometry/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/geometry/base.py b/craftsman/models/geometry/base.py index 5e71f4f9217199778c869f69495c616f19cac634..aabea8a23e582879ef73507d50518ecf1950962d 100755 --- a/craftsman/models/geometry/base.py +++ b/craftsman/models/geometry/base.py @@ -32,7 +32,7 @@ class BaseGeometry(BaseModule): f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}" ) - def export(self, *args, **kwargs): + def export(self, *args, **kwargs) -> Dict[str, Any]: return {} diff --git a/craftsman/models/transformers/__pycache__/attention.cpython-310.pyc b/craftsman/models/transformers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86aaa2ab10591ad95b7931524ed4ced417f7c8ac Binary files /dev/null and b/craftsman/models/transformers/__pycache__/attention.cpython-310.pyc differ diff --git a/craftsman/models/transformers/__pycache__/attention.cpython-38.pyc b/craftsman/models/transformers/__pycache__/attention.cpython-38.pyc deleted file mode 100644 index 326ccc1ec84109a7fbbe117701bbbe7761d4c4bb..0000000000000000000000000000000000000000 Binary files a/craftsman/models/transformers/__pycache__/attention.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/transformers/__pycache__/perceiver_1d.cpython-310.pyc b/craftsman/models/transformers/__pycache__/perceiver_1d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b0fc923ea0ee99e65f6e38b6a3e98df4ba5dd08 Binary files /dev/null and b/craftsman/models/transformers/__pycache__/perceiver_1d.cpython-310.pyc differ diff --git a/craftsman/models/transformers/__pycache__/perceiver_1d.cpython-38.pyc b/craftsman/models/transformers/__pycache__/perceiver_1d.cpython-38.pyc deleted file mode 100644 index 8436a67a6b8e258fb4f12268cea3904d20f685ef..0000000000000000000000000000000000000000 Binary files a/craftsman/models/transformers/__pycache__/perceiver_1d.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/transformers/__pycache__/utils.cpython-310.pyc b/craftsman/models/transformers/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba2de6a0419f045de38bb33d2f42ef89a29a6fa1 Binary files /dev/null and b/craftsman/models/transformers/__pycache__/utils.cpython-310.pyc differ diff --git a/craftsman/models/transformers/__pycache__/utils.cpython-38.pyc b/craftsman/models/transformers/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index 0e42e801d97cdacd9f824d868e00b5d8e7c509a6..0000000000000000000000000000000000000000 Binary files a/craftsman/models/transformers/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/craftsman/models/transformers/attention.py b/craftsman/models/transformers/attention.py old mode 100755 new mode 100644 index 7d4119b44b7bf84b238d9358bc1b0eb967cec032..81d179cf4e070c7a0520567b5e64dd3d7ed03792 --- a/craftsman/models/transformers/attention.py +++ b/craftsman/models/transformers/attention.py @@ -7,6 +7,128 @@ from craftsman.utils.typing import * from craftsman.utils.checkpoint import checkpoint from .utils import init_linear, MLP +from timm.models.vision_transformer import Attention + +def scaled_dot_product_gqa( + query: Tensor, + key: Tensor, + value: Tensor, + dropout: float = 0.0, + scale: Optional[float] = None, + mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None, + need_weights: bool = False, + average_attn_weights: bool = False, + force_grouped: bool = False, +): + """Scaled dot product attention with support for grouped queries. + + Einstein notation: + - b: batch size + - n / s: sequence length + - h: number of heads + - g: number of groups + - d: dimension of query/key/value + + Args: + query: Query tensor of shape (b, n, h, d) + key: Key tensor of shape (b, s, h, d) + value: Value tensor of shape (b, s, h, d) + dropout: Dropout probability (default: 0.0) + scale: Scale factor for query (default: d_query ** 0.5) + mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is + applied to all 'n' rows of the attention matrix. (default: None) + force_grouped: If True, apply grouped-query attention even if the number of + heads is equal for query, key, and value. (default: False) + + Returns: + 2-tuple of: + - Attention output with shape (b, n, h, d) + - (Optional) Attention weights with shape (b, h, n, s). Only returned if + 'need_weights' is True. + """ + if (mask is not None) and (is_causal is not None): + raise ValueError( + "Only one of 'mask' and 'is_causal' should be provided, but got both." + ) + elif not query.ndim == key.ndim == value.ndim == 4: + raise ValueError( + f"Expected query, key, and value to be 4-dimensional, but got shapes " + f"{query.shape}, {key.shape}, and {value.shape}." + ) + + # Move sequence length dimension to axis 2. + # This makes the attention operations below *much* faster. + query = rearrange(query, "b n h d -> b h n d") + key = rearrange(key, "b s h d -> b h s d") + value = rearrange(value, "b s h d -> b h s d") + + bq, hq, nq, dq = query.shape + bk, hk, nk, dk = key.shape + bv, hv, nv, dv = value.shape + if not (bq == bk == bv and dq == dk == dv): + raise ValueError( + "Expected query, key, and value to have the same batch size (dim=0) and " + f"embedding dimension (dim=3), but got query: {query.shape}, " + f"key: {key.shape}, and value: {value.shape}." + ) + elif (hk != hv) or (nk != nv): + raise ValueError( + "Expected key and value to have the same size in dimensions 1 and 2, but " + f"got key: {key.shape} and value: {value.shape}." + ) + elif hq % hk != 0: + raise ValueError( + "Expected query heads to be a multiple of key/value heads, but got " + f"query: {query.shape} and key/value: {key.shape}." + ) + + if scale is None: + scale = query.size(-1) ** 0.5 + query = query / scale + + num_head_groups = hq // hk + query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups) + similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s") + + if is_causal: + # Mask out the upper triangular portion of the attention matrix. This prevents + # the model from attending to tokens in the future. + mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_() + + if mask is not None: + # Expand mask to match the shape of the attention matrix. + # If mask is 2D, assume that it is applied to the key/value sequence dimension. + # Else if mask is 3D, assume that it is applied to the query/key/value sequence + # dimension for all attention heads. + # + if mask.ndim == 2: + mask = rearrange(mask, "b s -> b () () () s") + elif mask.ndim == 3: + mask = rearrange(mask, "b n s -> b () () n s") + # Mask similarity values by setting them to negative infinity. This guarantees + # that they will not contribute to the softmax computation below. + similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min) + + attention = F.softmax(similarity, dim=-1) + if dropout > 0.0: + attention = F.dropout(attention, p=dropout) + + # Apply attention matrix to the value Tensor. + out = einsum(attention, value, "b g h n s, b h s d -> b g h n d") + # Move head dimension back to axis 2 + out = rearrange(out, "b g h n d -> b n (h g) d") + + attn_weights: Optional[Tensor] = None + if need_weights: + # Move the sequence dimensions back to positions 1, 2. Move the head dimension + # to position 3. This more closely matches the return shape of the attention + # output: (b, n, h, d). + attn_weights = rearrange(attention, "b g h n s -> b n s (h g)") + if average_attn_weights: + attn_weights = attn_weights.mean(dim=1) + + return out, attn_weights class MultiheadAttention(nn.Module): def __init__( @@ -155,6 +277,7 @@ class QKVMultiheadCrossAttention(nn.Module): k, v = torch.split(kv, attn_ch, dim=-1) if self.use_flash: + q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) diff --git a/craftsman/models/transformers/perceiver_1d.py b/craftsman/models/transformers/perceiver_1d.py old mode 100755 new mode 100644 diff --git a/craftsman/models/transformers/utils.py b/craftsman/models/transformers/utils.py old mode 100755 new mode 100644 index 1ed8bbf3eaf6e7bddc489dcc929ed1eab561409f..fb48fb9fe6019eedc9bd4f1eadf1eeb3fad3ca43 --- a/craftsman/models/transformers/utils.py +++ b/craftsman/models/transformers/utils.py @@ -18,4 +18,4 @@ class MLP(nn.Module): init_linear(self.c_proj, init_scale) def forward(self, x): - return self.c_proj(self.gelu(self.c_fc(x))) + return self.c_proj(self.gelu(self.c_fc(x))) \ No newline at end of file diff --git a/craftsman/pipeline.py b/craftsman/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..15edafd62330df4c7d3d282ab09bcd228a6353b4 --- /dev/null +++ b/craftsman/pipeline.py @@ -0,0 +1,269 @@ +import os +import warnings +from typing import Callable, List, Optional, Union, Dict, Any +import PIL.Image +import trimesh +import rembg +import torch +import numpy as np +from huggingface_hub import hf_hub_download +from diffusers.utils import BaseOutput + +import craftsman +from craftsman.utils.config import ExperimentConfig, load_config + +class MeshPipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[trimesh.Trimesh]` or `np.ndarray`) + List of denoised trimesh meshes of length `batch_size` or a tuple of NumPy array with shape `((vertices, 3), (faces, 3)) of length `batch_size``. + """ + + meshes: Union[List[trimesh.Trimesh], np.ndarray] + + +class CraftsManPipeline(): + """ + Pipeline for text-guided image to image generation using CraftsMan(https://github.com/wyysf-98/CraftsMan). + + Args: + feature_extractor ([`CLIPFeatureExtractor`]): + Feature extractor for image pre-processing before being encoded. + """ + def __init__( + self, + device: str, + cfg: ExperimentConfig, + system, + ): + self.device = device + self.cfg = cfg + self.system = system + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + A simpler version that instantiate a PyTorch diffusion pipeline from pretrained pipeline weights. + The pipeline is set in evaluation mode (`model.eval()`) by default. + """ + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + ckpt_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="model.ckpt", repo_type="model") + config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.yaml", repo_type="model") + else: + ckpt_path = os.path.join(pretrained_model_name_or_path, "model.ckpt") + config_path = os.path.join(pretrained_model_name_or_path, "config.yaml") + + # 2. Load the model + device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") + cfg = load_config(config_path) + system = craftsman.find(cfg.system_type)(cfg.system) + print(f"Restoring states from the checkpoint path at {ckpt_path} with config {cfg}") + ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) + system.load_state_dict( + ckpt["state_dict"] if "state_dict" in ckpt else ckpt, + ) + system = system.to(device).eval() + + return cls( + device=device, + cfg=cfg, + system=system + ) + + def check_inputs( + self, + image, + ): + r""" + Check if the inputs are valid. Raise an error if not. + """ + if isinstance(image, str): + assert os.path.isfile(image) or image.startswith("http"), "Input image must be a valid URL or a file path." + elif isinstance(image, (torch.Tensor, PIL.Image.Image)): + raise ValueError("Input image must be a `torch.Tensor` or `PIL.Image.Image`.") + + def preprocess_image( + self, + images_pil: List[PIL.Image.Image], + force: bool = False, + background_color: List[int] = [255, 255, 255], + foreground_ratio: float = 1.0, + ): + r""" + Crop and remote the background of the input image + Args: + image_pil (`List[PIL.Image.Image]`): + List of `PIL.Image.Image` objects representing the input image. + force (`bool`, *optional*, defaults to `False`): + Whether to force remove the background even if the image has an alpha channel. + Returns: + `List[PIL.Image.Image]`: List of `PIL.Image.Image` objects representing the preprocessed image. + """ + preprocessed_images = [] + for i in range(len(images_pil)): + image = images_pil[i] + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + # explain why current do not rm bg + print("alhpa channl not enpty, skip remove background, using alpha channel as mask") + background = PIL.Image.new("RGBA", image.size, (*background_color, 0)) + image = PIL.Image.alpha_composite(background, image) + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image) + + # calculate the min bbox of the image + alpha = image.split()[-1] + image = image.crop(alpha.getbbox()) + + # Calculate the new size after rescaling + new_size = tuple(int(dim * foreground_ratio) for dim in image.size) + # Resize the image while maintaining the aspect ratio + resized_image = image.resize(new_size) + # Create a new image with the original size and white background + padded_image = PIL.Image.new("RGBA", image.size, (*background_color, 0)) + paste_position = ((image.width - resized_image.width) // 2, (image.height - resized_image.height) // 2) + padded_image.paste(resized_image, paste_position) + + # expand image to 1:1 + width, height = padded_image.size + if width == height: + preprocessed_images.append(padded_image) + continue + new_size = (max(width, height), max(width, height)) + new_image = PIL.Image.new("RGBA", new_size, (*background_color, 1)) + paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) + new_image.paste(padded_image, paste_position) + preprocessed_images.append(new_image) + + return preprocessed_images + + @torch.no_grad() + def __call__( + self, + image: Union[torch.FloatTensor, PIL.Image.Image, str], + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + eta: float = 0.0, + num_meshes_per_prompt: Optional[int] = 1, + output_type: Optional[str] = "trimesh", + return_dict: bool = True, + seed: Optional[int] = None, + force_remove_background: bool = False, + background_color: List[int] = [255, 255, 255], + foreground_ratio: float = 0.95, + mc_depth: int = 8, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch. The image will be encoded to its CLIP/DINO-v2 embedding + which the DiT will be conditioned on. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + The eta parameter as defined in [DDIM](https://arxiv.org/abs/2010.02502). `eta` is a parameter that + controls the amount of noise added to the latent space. It is only used with the DDIM scheduler and + will be ignored for other schedulers. `eta` should be between [0, 1]. + num_meshes_per_prompt (`int`, *optional*, defaults to 1): + The number of meshes to generate per prompt. + output_type (`str`, *optional*, defaults to `"trimesh"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image`, `latents` or `np.array of v and f`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + seed (`int`, *optional*, defaults to `None`): + Seed for the random number generator. Setting a seed will ensure reproducibility. + force_remove_background (`bool`, *optional*, defaults to `False`): + Whether to force remove the background even if the image has an alpha channel. + foreground_ratio (`float`, *optional*, defaults to 1.0): + The ratio of the foreground in the image. The foreground is the part of the image that is not the + background. The foreground is resized to the size of the background image while maintaining the aspect + ratio. The background is filled with black color. The foreground ratio should be between [0, 1]. + mc_depth (`int`, *optional*, defaults to 8): + The resolution of the Marching Cubes algorithm. The resolution is the number of cubes in the x, y, and z. + 8 means 2^8 = 256 cubes in each dimension. The higher the resolution, the more detailed the mesh will be. + Examples: + + Returns: + [`~MeshPipelineOutput`] or `tuple`: [`~MeshPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is a list with the generated meshes. + """ + # 0. Check inputs. Raise error if not correct + self.check_inputs( + image=image, + ) + + # 1. Define call parameters + if isinstance(image, torch.Tensor): + batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image) or isinstance(image, str): + batch_size = 1 + do_classifier_free_guidance = guidance_scale != 1.0 + + # 2. Preprocess input image + if isinstance(image, torch.Tensor): + images_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])] + elif isinstance(image, PIL.Image.Image): + images_pil = [image] + elif isinstance(image, str): + if image.startswith("http"): + import requests + images_pil = [PIL.Image.open(requests.get(image, stream=True).raw)] + else: + images_pil = [PIL.Image.open(image)] + images_pil = self.preprocess_image( + images_pil, + force=force_remove_background, + background_color=background_color, + foreground_ratio=foreground_ratio + ) + + # 3. Inference + latents = self.system.sample( + {'image': images_pil}, + sample_times = num_meshes_per_prompt, + steps = num_inference_steps, + guidance_scale = guidance_scale, + eta = eta, + seed = seed + ) + + # 4. Post-processing + if not output_type == "latent": + mesh = [] + for i, cur_latents in enumerate(latents): + print(f"Generating mesh {i+1}/{num_meshes_per_prompt}") + mesh_v_f, has_surface = self.system.shape_model.extract_geometry( + cur_latents, + octree_depth=mc_depth, + extract_mesh_func="mc" + ) + + if output_type == "trimesh": + import trimesh + cur_mesh = trimesh.Trimesh(vertices=mesh_v_f[0][0], faces=mesh_v_f[0][1]) + mesh.append(cur_mesh) + elif output_type == "np": + mesh.append(mesh_v_f[0]) + else: + mesh = latents + + if not return_dict: + return tuple(mesh) + return MeshPipelineOutput(meshes=mesh) \ No newline at end of file diff --git a/craftsman/systems/__init__.py b/craftsman/systems/__init__.py index c2b32ffee706936f19bd2cb8ae4c7ba070cc4288..f2cca37702911e140de49df9868ef1c4649ed11b 100755 --- a/craftsman/systems/__init__.py +++ b/craftsman/systems/__init__.py @@ -1,4 +1,4 @@ from . import ( shape_autoencoder, - shape_diffusion, + pixart_diffusion, ) \ No newline at end of file diff --git a/craftsman/systems/__pycache__/__init__.cpython-310.pyc b/craftsman/systems/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a33a8274329244348289bbfd7de12c4a6533a0d1 Binary files /dev/null and b/craftsman/systems/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/systems/__pycache__/__init__.cpython-38.pyc b/craftsman/systems/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index efd8d6e9dd06baaeaaa60857791f65a1f37af7f4..0000000000000000000000000000000000000000 Binary files a/craftsman/systems/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/systems/__pycache__/base.cpython-310.pyc b/craftsman/systems/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..368251a7d4d3b344031271b791f22c803dbea049 Binary files /dev/null and b/craftsman/systems/__pycache__/base.cpython-310.pyc differ diff --git a/craftsman/systems/__pycache__/base.cpython-38.pyc b/craftsman/systems/__pycache__/base.cpython-38.pyc deleted file mode 100644 index 0d3042ca213f65d0e2abf78bc54351750bf4e338..0000000000000000000000000000000000000000 Binary files a/craftsman/systems/__pycache__/base.cpython-38.pyc and /dev/null differ diff --git a/craftsman/systems/__pycache__/pixart_diffusion.cpython-310.pyc b/craftsman/systems/__pycache__/pixart_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dc2e3115a0738aa6d00e90c50b83cae66274e60 Binary files /dev/null and b/craftsman/systems/__pycache__/pixart_diffusion.cpython-310.pyc differ diff --git a/craftsman/systems/__pycache__/shape_autoencoder.cpython-310.pyc b/craftsman/systems/__pycache__/shape_autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12bf192a5b35186b6aa033695f0ea61c3272a554 Binary files /dev/null and b/craftsman/systems/__pycache__/shape_autoencoder.cpython-310.pyc differ diff --git a/craftsman/systems/__pycache__/shape_autoencoder.cpython-38.pyc b/craftsman/systems/__pycache__/shape_autoencoder.cpython-38.pyc deleted file mode 100644 index 10b054083a07ea2de75d0ef275da6f450f62b130..0000000000000000000000000000000000000000 Binary files a/craftsman/systems/__pycache__/shape_autoencoder.cpython-38.pyc and /dev/null differ diff --git a/craftsman/systems/__pycache__/shape_diffusion.cpython-38.pyc b/craftsman/systems/__pycache__/shape_diffusion.cpython-38.pyc deleted file mode 100644 index 7a0a47d423e16d335c8cef823ec9a53ae6d135e0..0000000000000000000000000000000000000000 Binary files a/craftsman/systems/__pycache__/shape_diffusion.cpython-38.pyc and /dev/null differ diff --git a/craftsman/systems/__pycache__/utils.cpython-310.pyc b/craftsman/systems/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df3cfcd10fb8c92650f8213eba55c56fdd38a63 Binary files /dev/null and b/craftsman/systems/__pycache__/utils.cpython-310.pyc differ diff --git a/craftsman/systems/shape_diffusion.py b/craftsman/systems/pixart_diffusion.py old mode 100755 new mode 100644 similarity index 76% rename from craftsman/systems/shape_diffusion.py rename to craftsman/systems/pixart_diffusion.py index 5ab9725341e86cf859802ed5e07b2858b06cd41d..529167af2542f7d857defa39074e8047817b0b8c --- a/craftsman/systems/shape_diffusion.py +++ b/craftsman/systems/pixart_diffusion.py @@ -20,9 +20,9 @@ from diffusers import ( import craftsman from craftsman.systems.base import BaseSystem -from craftsman.utils.ops import generate_dense_grid_points from craftsman.utils.misc import get_rank from craftsman.utils.typing import * +from diffusers import DDIMScheduler def compute_snr(noise_scheduler, timesteps): """ @@ -49,6 +49,7 @@ def compute_snr(noise_scheduler, timesteps): snr = (alpha / sigma) ** 2 return snr + def ddim_sample(ddim_scheduler: DDIMScheduler, diffusion_model: torch.nn.Module, shape: Union[List[int], Tuple[int]], @@ -114,11 +115,15 @@ def ddim_sample(ddim_scheduler: DDIMScheduler, yield latents, t -@craftsman.register("shape-diffusion-system") -class ShapeDiffusionSystem(BaseSystem): +# DEBUG = True +@craftsman.register("pixart-diffusion-system") +class PixArtDiffusionSystem(BaseSystem): @dataclass class Config(BaseSystem.Config): val_samples_json: str = None + extract_mesh_func: str = "mc" + + # diffusion config z_scale_factor: float = 1.0 guidance_scale: float = 7.5 num_inference_steps: int = 50 @@ -155,60 +160,71 @@ class ShapeDiffusionSystem(BaseSystem): self.shape_model.requires_grad_(False) self.condition = craftsman.find(self.cfg.condition_model_type)(self.cfg.condition_model) - + self.denoiser_model = craftsman.find(self.cfg.denoiser_model_type)(self.cfg.denoiser_model) self.noise_scheduler = craftsman.find(self.cfg.noise_scheduler_type)(**self.cfg.noise_scheduler) - self.denoise_scheduler = craftsman.find(self.cfg.denoise_scheduler_type)(**self.cfg.denoise_scheduler) - self.z_scale_factor = self.cfg.z_scale_factor + self.denoise_scheduler = craftsman.find(self.cfg.denoise_scheduler_type)(**self.cfg.denoise_scheduler) - def forward(self, batch: Dict[str, Any]): - # encode shape latents - shape_embeds, kl_embed, posterior = self.shape_model.encode( + def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]: + # 1. encode shape latents + shape_embeds, kl_embed, _ = self.shape_model.encode( batch["surface"][..., :3 + self.cfg.shape_model.point_feats], sample_posterior=True ) - latents = kl_embed * self.z_scale_factor - cond_latents = self.condition(batch) - cond_latents = cond_latents.to(latents).view(latents.shape[0], -1, cond_latents.shape[-1]) + latents = kl_embed * self.cfg.z_scale_factor + + # 2. gain condition. assert not (text_cond and image_cond), "Only one of text or image condition must be provided." + if "image" in batch and batch['image'].dim() == 5: + if self.training: + bs, n_images = batch['image'].shape[:2] + batch['image'] = batch['image'].view(bs*n_images, *batch['image'].shape[-3:]) + else: + batch['image'] = batch['image'][:, 0, ...] + n_images = 1 + bs = batch['image'].shape[0] + cond_latents = self.condition(batch).to(latents) + latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1) + latents = latents.view(bs*n_images, *latents.shape[-2:]) + else: + cond_latents = self.condition(batch).to(latents) + cond_latents = cond_latents.view(cond_latents.shape[0], -1, cond_latents.shape[-1]) - # Sample noise that we"ll add to the latents - # [batch_size, n_token, latent_dim] - noise = torch.randn_like(latents).to(latents) + # 3. sample noise that we"ll add to the latents + noise = torch.randn_like(latents).to(latents) # [batch_size, n_token, latent_dim] bs = latents.shape[0] - # Sample a random timestep for each motion + + # 4. Sample a random timestep for each motion timesteps = torch.randint( 0, - self.noise_scheduler.config.num_train_timesteps, + self.cfg.noise_scheduler.num_train_timesteps, (bs,), device=latents.device, ) - # import pdb; pdb.set_trace() - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # x_t + + # 5. add noise noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) - - # diffusion model forward + + # 6. diffusion model forward noise_pred = self.denoiser_model(noisy_z, timesteps, cond_latents) - # compute loss + # 7. compute loss if self.noise_scheduler.config.prediction_type == "epsilon": target = noise elif self.noise_scheduler.config.prediction_type == "v_prediction": target = self.noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") + raise ValueError(f"Prediction Type: {self.noise_scheduler.prediction_type} not supported.") if self.cfg.snr_gamma == 0: if self.cfg.loss.loss_type == "l1": loss = F.l1_loss(noise_pred, target, reduction="mean") elif self.cfg.loss.loss_type in ["mse", "l2"]: loss = F.mse_loss(noise_pred, target, reduction="mean") else: - raise NotImplementedError(f"Loss Type: {self.cfg.loss.loss_type} not yet supported.") + raise ValueError(f"Loss Type: {self.cfg.loss.loss_type} not supported.") else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. @@ -219,7 +235,7 @@ class ShapeDiffusionSystem(BaseSystem): )[0] if self.noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr - elif noise_scheduler.config.prediction_type == "v_prediction": + elif self.noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = mse_loss_weights / (snr + 1) if self.cfg.loss.loss_type == "l1": @@ -227,16 +243,17 @@ class ShapeDiffusionSystem(BaseSystem): elif self.cfg.loss.loss_type in ["mse", "l2"]: loss = F.mse_loss(noise_pred, target, reduction="none") else: - raise NotImplementedError(f"Loss Type: {self.cfg.loss.loss_type} not yet supported.") + raise ValueError(f"Loss Type: {self.cfg.loss.loss_type} not supported.") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() + return { "loss_diffusion": loss, "latents": latents, - "x_0": noisy_z, + "x_t": x_t, "noise": noise, - "noise_pred": noise_pred, + "noise_pred": pred_noise, "timesteps": timesteps, } @@ -258,21 +275,16 @@ class ShapeDiffusionSystem(BaseSystem): @torch.no_grad() def validation_step(self, batch, batch_idx): self.eval() - + if get_rank() == 0: sample_inputs = json.loads(open(self.cfg.val_samples_json).read()) # condition sample_inputs_ = copy.deepcopy(sample_inputs) sample_outputs = self.sample(sample_inputs) # list for i, sample_output in enumerate(sample_outputs): - mesh_v_f, has_surface = self.shape_model.extract_geometry(sample_output, octree_depth=7) + mesh_v_f, has_surface = self.shape_model.extract_geometry(sample_output, octree_depth=7, extract_mesh_func=self.cfg.extract_mesh_func) + for j in range(len(mesh_v_f)): - if "text" in sample_inputs_ and "image" in sample_inputs_: - name = sample_inputs_["image"][j].split("/")[-1].replace(".png", "") - elif "text" in sample_inputs_ and "mvimage" in sample_inputs_: - name = sample_inputs_["mvimages"][j][0].split("/")[-2].replace(".png", "") - elif "text" in sample_inputs_: - name = sample_inputs_["text"][j].replace(" ", "_") - elif "image" in sample_inputs_: + if "image" in sample_inputs_: name = sample_inputs_["image"][j].split("/")[-1].replace(".png", "") elif "mvimages" in sample_inputs_: name = sample_inputs_["mvimages"][j][0].split("/")[-2].replace(".png", "") @@ -284,16 +296,15 @@ class ShapeDiffusionSystem(BaseSystem): out = self(batch) if self.global_step == 0: latents = self.shape_model.decode(out["latents"]) - mesh_v_f, has_surface = self.shape_model.extract_geometry(latents) + mesh_v_f, has_surface = self.shape_model.extract_geometry(latents=latents, extract_mesh_func=self.cfg.extract_mesh_func) + self.save_mesh( f"it{self.true_global_step}/{batch['uid'][0]}_{batch['sel_idx'][0] if 'sel_idx' in batch.keys() else 0}.obj", mesh_v_f[0][0], mesh_v_f[0][1] ) - # exit() - torch.cuda.empty_cache() return {"val/loss": out["loss_diffusion"]} - + @torch.no_grad() def sample(self, sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]], @@ -301,8 +312,6 @@ class ShapeDiffusionSystem(BaseSystem): steps: Optional[int] = None, guidance_scale: Optional[float] = None, eta: float = 0.0, - return_intermediates: bool = False, - camera_embeds: Optional[torch.Tensor] = None, seed: Optional[int] = None, **kwargs): @@ -310,11 +319,11 @@ class ShapeDiffusionSystem(BaseSystem): steps = self.cfg.num_inference_steps if guidance_scale is None: guidance_scale = self.cfg.guidance_scale - do_classifier_free_guidance = guidance_scale > 0 + do_classifier_free_guidance = guidance_scale != 1.0 # conditional encode if "image" in sample_inputs: - sample_inputs["image"] = [Image.open(img) for img in sample_inputs["image"]] + sample_inputs["image"] = [Image.open(img) if type(img) == str else img for img in sample_inputs["image"]] cond = self.condition.encode_image(sample_inputs["image"]) if do_classifier_free_guidance: un_cond = self.condition.empty_image_embeds.repeat(len(sample_inputs["image"]), 1, 1).to(cond) @@ -328,12 +337,12 @@ class ShapeDiffusionSystem(BaseSystem): else: sample_inputs["image"] = image cond += [self.condition.encode_image(sample_inputs["image"])] - cond = torch.stack(cond, dim=0)# tensor shape 为[len(sample_inputs["mvimages"], 4*(num_latents+1), context_dim] + cond = torch.stack(cond, dim=0).view(bs, -1, self.cfg.denoiser_model.context_dim) if do_classifier_free_guidance: - un_cond = self.condition.empty_image_embeds.unsqueeze(0).repeat(len(sample_inputs["mvimages"]), cond.shape[1] // self.condition.cfg.n_views, 1, 1).to(cond) # shape 为[len(sample_inputs["mvimages"], 4*(num_latents+1), context_dim] + un_cond = self.condition.empty_image_embeds.unsqueeze(0).repeat(len(sample_inputs["mvimages"]), 1, 1, 1).view(bs, cond.shape[1], self.cfg.denoiser_model.context_dim).to(cond) # shape 为[len(sample_inputs["mvimages"], 4*(num_latents+1), context_dim] cond = torch.cat([un_cond, cond], dim=0).view(bs * 2, -1, cond[0].shape[-1]) else: - raise NotImplementedError("Only text, image or mvimages condition is supported.") + raise NotImplementedError("Only image or mvimages condition is supported.") outputs = [] latents = None @@ -342,26 +351,8 @@ class ShapeDiffusionSystem(BaseSystem): generator = torch.Generator(device="cuda").manual_seed(seed) else: generator = None - - if not return_intermediates: - for _ in range(sample_times): - sample_loop = ddim_sample( - self.denoise_scheduler, - self.denoiser_model.eval(), - shape=self.shape_model.latent_shape, - cond=cond, - steps=steps, - guidance_scale=guidance_scale, - do_classifier_free_guidance=do_classifier_free_guidance, - device=self.device, - eta=eta, - disable_prog=False, - generator= generator - ) - for sample, t in sample_loop: - latents = sample - outputs.append(self.shape_model.decode(latents / self.z_scale_factor, **kwargs)) - else: + + for _ in range(sample_times): sample_loop = ddim_sample( self.denoise_scheduler, self.denoiser_model.eval(), @@ -375,17 +366,11 @@ class ShapeDiffusionSystem(BaseSystem): disable_prog=False, generator= generator ) - - iter_size = steps // sample_times - i = 0 for sample, t in sample_loop: latents = sample - if i % iter_size == 0 or i == steps - 1: - outputs.append(self.shape_model.decode(latents / self.z_scale_factor, **kwargs)) - i += 1 - + outputs.append(self.shape_model.decode(latents / self.cfg.z_scale_factor, **kwargs)) + return outputs - def on_validation_epoch_end(self): - pass + pass \ No newline at end of file diff --git a/craftsman/systems/shape_autoencoder.py b/craftsman/systems/shape_autoencoder.py index 383a1ecd7a5490c353782bd2c6ecc7aaa4f30772..9a45baa4888a1af24622991222b2b6867227dea0 100755 --- a/craftsman/systems/shape_autoencoder.py +++ b/craftsman/systems/shape_autoencoder.py @@ -124,7 +124,6 @@ class ShapeAutoEncoderSystem(BaseSystem): union = (pred + labels).gt(0).sum(dim=1) iou = intersection * 1.0 / union + 1e-5 iou = iou.mean() - self.log("val/accuracy", accuracy) self.log("val/iou", iou) diff --git a/craftsman/utils/__pycache__/__init__.cpython-310.pyc b/craftsman/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9c8c4a21e1ddb79aab68a042d8c41e41bae9ad4 Binary files /dev/null and b/craftsman/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/__init__.cpython-38.pyc b/craftsman/utils/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 7765274149e74f3ca16da7a2960b729debc84839..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/base.cpython-310.pyc b/craftsman/utils/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83a68795573ff8940846efa2c7d038d175d73b83 Binary files /dev/null and b/craftsman/utils/__pycache__/base.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/base.cpython-38.pyc b/craftsman/utils/__pycache__/base.cpython-38.pyc deleted file mode 100644 index f487d7c264d608048815518e6543948be4259fa5..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/base.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/checkpoint.cpython-310.pyc b/craftsman/utils/__pycache__/checkpoint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd5b481c93c34571f33c0fc48ef145ed23b1fa58 Binary files /dev/null and b/craftsman/utils/__pycache__/checkpoint.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/checkpoint.cpython-38.pyc b/craftsman/utils/__pycache__/checkpoint.cpython-38.pyc deleted file mode 100644 index 5ee1ebb32890cec380dd6b2b3fd8987271d7c815..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/checkpoint.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/config.cpython-310.pyc b/craftsman/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aac3aef5468c180268a0c8c52a1707e60d870822 Binary files /dev/null and b/craftsman/utils/__pycache__/config.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/config.cpython-38.pyc b/craftsman/utils/__pycache__/config.cpython-38.pyc deleted file mode 100644 index 2635fb7fd002183ae197135e57b72fd8684cbd12..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/config.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/misc.cpython-310.pyc b/craftsman/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3adf88e303ffbe6e877b42cbadc7f326fcbad751 Binary files /dev/null and b/craftsman/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/misc.cpython-38.pyc b/craftsman/utils/__pycache__/misc.cpython-38.pyc deleted file mode 100644 index 20f0ef3838a60a0ef1f8b6f326f27d2a361e6f4a..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/misc.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/ops.cpython-310.pyc b/craftsman/utils/__pycache__/ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..debfa41eae014018d9d0ba540248b2f77de3c23b Binary files /dev/null and b/craftsman/utils/__pycache__/ops.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/ops.cpython-38.pyc b/craftsman/utils/__pycache__/ops.cpython-38.pyc deleted file mode 100644 index 2eca9760971a93e8744184314781d05aecd5bb14..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/ops.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/saving.cpython-310.pyc b/craftsman/utils/__pycache__/saving.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b38d835c68dc5a67dad09beea97d2bb2f47fcf4 Binary files /dev/null and b/craftsman/utils/__pycache__/saving.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/saving.cpython-38.pyc b/craftsman/utils/__pycache__/saving.cpython-38.pyc deleted file mode 100644 index f18acb05a059c37e691d02e99b684e268394a30d..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/saving.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/scheduler.cpython-310.pyc b/craftsman/utils/__pycache__/scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7f9b4184fd39cf6b2df5076a97681f89ea577fe Binary files /dev/null and b/craftsman/utils/__pycache__/scheduler.cpython-310.pyc differ diff --git a/craftsman/utils/__pycache__/scheduler.cpython-38.pyc b/craftsman/utils/__pycache__/scheduler.cpython-38.pyc deleted file mode 100644 index fea013c3f079bdd852119f654f88f1cd24851939..0000000000000000000000000000000000000000 Binary files a/craftsman/utils/__pycache__/scheduler.cpython-38.pyc and /dev/null differ diff --git a/craftsman/utils/__pycache__/typing.cpython-38.pyc b/craftsman/utils/__pycache__/typing.cpython-310.pyc similarity index 84% rename from craftsman/utils/__pycache__/typing.cpython-38.pyc rename to craftsman/utils/__pycache__/typing.cpython-310.pyc index 4a4b821005a73f47532a0d7a0d38b0230612b4e8..24970b9142feaba6029154274a042c4c34fc6e92 100644 Binary files a/craftsman/utils/__pycache__/typing.cpython-38.pyc and b/craftsman/utils/__pycache__/typing.cpython-310.pyc differ diff --git a/craftsman/utils/base.py b/craftsman/utils/base.py index 6a9bef941db99cd5f5b1e5897b247a1c91837167..2fa680e1a683beb47f254f74e5133c86da2ca8eb 100755 --- a/craftsman/utils/base.py +++ b/craftsman/utils/base.py @@ -57,12 +57,12 @@ class Updateable: pass -def update_if_possible(module, epoch: int, global_step: int) -> None: +def update_if_possible(module: Any, epoch: int, global_step: int) -> None: if isinstance(module, Updateable): module.do_update_step(epoch, global_step) -def update_end_if_possible(module, epoch: int, global_step: int) -> None: +def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: if isinstance(module, Updateable): module.do_update_step_end(epoch, global_step) diff --git a/craftsman/utils/callbacks.py b/craftsman/utils/callbacks.py index c0382c56d4a11ce1f6627324cf082aca74b32f05..daaf2c34d7864560cebcaa53d64310db534301a0 100755 --- a/craftsman/utils/callbacks.py +++ b/craftsman/utils/callbacks.py @@ -16,6 +16,23 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn +class EarlyEnvironmentSetter(Callback): + def __init__(self): + super().__init__() + self.rank_set = False + + def setup(self, trainer, pl_module, stage): + if not self.rank_set: + world_size = trainer.num_devices + local_rank = trainer.strategy.local_rank + + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['RANK'] = str(local_rank) + + self.rank_set = True + class VersionedCallback(Callback): def __init__(self, save_root, version=None, use_version=True): self.save_root = save_root diff --git a/craftsman/utils/config.py b/craftsman/utils/config.py index 8e7909615806baffd62785de4edba326a45f37ad..ae4f23ad2d59f5a30123f6c4c097909969f1d46f 100755 --- a/craftsman/utils/config.py +++ b/craftsman/utils/config.py @@ -28,7 +28,7 @@ OmegaConf.register_new_resolver( # ======================================================= # -def C_max(value) -> float: +def C_max(value: Any) -> float: if isinstance(value, int) or isinstance(value, float): pass else: @@ -98,10 +98,10 @@ class ExperimentConfig: self.trial_name += self.timestamp self.exp_dir = os.path.join(self.exp_root_dir, self.name) self.trial_dir = os.path.join(self.exp_dir, self.trial_name) - os.makedirs(self.trial_dir, exist_ok=True) + # os.makedirs(self.trial_dir, exist_ok=True) -def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs): +def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: if from_string: yaml_confs = [OmegaConf.create(s) for s in yamls] else: @@ -114,7 +114,7 @@ def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs): return scfg -def config_to_primitive(config, resolve: bool = True): +def config_to_primitive(config, resolve: bool = True) -> Any: return OmegaConf.to_container(config, resolve=resolve) @@ -123,6 +123,6 @@ def dump_config(path: str, config) -> None: OmegaConf.save(config=config, f=fp) -def parse_structured(fields, cfg: Optional[Union[dict, DictConfig]] = None): +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: scfg = OmegaConf.structured(fields(**cfg)) return scfg \ No newline at end of file diff --git a/craftsman/utils/misc.py b/craftsman/utils/misc.py index c3244f4f856acf50e59d2d3816632a03ae89da54..732835e7f4d2dba74a22c34f876ef4762db754e3 100755 --- a/craftsman/utils/misc.py +++ b/craftsman/utils/misc.py @@ -70,7 +70,7 @@ def load_module_weights( return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] -def C(value, epoch: int, global_step: int) -> float: +def C(value: Any, epoch: int, global_step: int) -> float: if isinstance(value, int) or isinstance(value, float): pass else: diff --git a/craftsman/utils/saving.py b/craftsman/utils/saving.py index 3f20a0dbf4c188f6efdac7d585be0480b79fdca4..8501d684769753ce1c44556ef649b7fbcd58cb17 100755 --- a/craftsman/utils/saving.py +++ b/craftsman/utils/saving.py @@ -645,6 +645,12 @@ class SaverMixin: save_path = self.get_save_path(filename) shutil.copyfile(src_path, save_path) return save_path + + def save_txt(self, filename, comment) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(comment) + return save_path def save_json(self, filename, payload) -> str: save_path = self.get_save_path(filename) diff --git a/craftsman/utils/visualizers/__init__.py b/craftsman/utils/visualizers/__init__.py deleted file mode 100755 index 40a96afc6ff09d58a702b76e3f7dd412fe975e26..0000000000000000000000000000000000000000 --- a/craftsman/utils/visualizers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/craftsman/utils/visualizers/color_util.py b/craftsman/utils/visualizers/color_util.py deleted file mode 100755 index 7983243fd37f5fee47bc51475dc58c460a067830..0000000000000000000000000000000000000000 --- a/craftsman/utils/visualizers/color_util.py +++ /dev/null @@ -1,43 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt - - -# Helper functions -def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): - colormap = plt.cm.get_cmap(colormap) - if normalize: - vmin = np.min(inp) - vmax = np.max(inp) - - norm = plt.Normalize(vmin, vmax) - return colormap(norm(inp))[:, :3] - - -def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): - # tex dims need to be power of two. - array = np.ones((width, height, 3), dtype='float32') - - # width in texels of each checker - checker_w = width / n_checkers_x - checker_h = height / n_checkers_y - - for y in range(height): - for x in range(width): - color_key = int(x / checker_w) + int(y / checker_h) - if color_key % 2 == 0: - array[x, y, :] = [1., 0.874, 0.0] - else: - array[x, y, :] = [0., 0., 0.] - return array - - -def gen_circle(width=256, height=256): - xx, yy = np.mgrid[:width, :height] - circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2 - array = np.ones((width, height, 4), dtype='float32') - array[:, :, 0] = (circle <= width) - array[:, :, 1] = (circle <= width) - array[:, :, 2] = (circle <= width) - array[:, :, 3] = circle <= width - return array - diff --git a/craftsman/utils/visualizers/html_util.py b/craftsman/utils/visualizers/html_util.py deleted file mode 100755 index f90fe6cfefe6108655b48c36d60db537589993d5..0000000000000000000000000000000000000000 --- a/craftsman/utils/visualizers/html_util.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -import io -import base64 -import numpy as np -from PIL import Image - - -def to_html_frame(content): - - html_frame = f""" - - - {content} - - - """ - - return html_frame - - -def to_single_row_table(caption: str, content: str): - - table_html = f""" - - - - - -
{caption}
{content}
- """ - - return table_html - - -def to_image_embed_tag(image: np.ndarray): - - # Convert np.ndarray to bytes - img = Image.fromarray(image) - raw_bytes = io.BytesIO() - img.save(raw_bytes, "PNG") - - # Encode bytes to base64 - image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8") - - image_tag = f""" - Embedded Image - """ - - return image_tag diff --git a/craftsman/utils/visualizers/pythreejs_viewer.py b/craftsman/utils/visualizers/pythreejs_viewer.py deleted file mode 100755 index 70abbb7dce17361489f36eddca9c7e3089d4ec05..0000000000000000000000000000000000000000 --- a/craftsman/utils/visualizers/pythreejs_viewer.py +++ /dev/null @@ -1,537 +0,0 @@ -import numpy as np -from ipywidgets import embed -import pythreejs as p3s -import uuid - -from .color_util import get_colors, gen_circle, gen_checkers - - -EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js" - - -class PyThreeJSViewer(object): - - def __init__(self, settings, render_mode="WEBSITE"): - self.render_mode = render_mode - self.__update_settings(settings) - self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6) - self._light2 = p3s.AmbientLight(intensity=0.5) - self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"], - aspect=self.__s["width"] / self.__s["height"], children=[self._light]) - self._orbit = p3s.OrbitControls(controlling=self._cam) - self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"]) # "#4c4c80" - self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit], - width=self.__s["width"], height=self.__s["height"], - antialias=self.__s["antialias"]) - - self.__objects = {} - self.__cnt = 0 - - def jupyter_mode(self): - self.render_mode = "JUPYTER" - - def offline(self): - self.render_mode = "OFFLINE" - - def website(self): - self.render_mode = "WEBSITE" - - def __get_shading(self, shading): - shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black", - "side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None], - "bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0, - "line_width": 1.0, "line_color": "black", - "point_color": "red", "point_size": 0.01, "point_shape": "circle", - "text_color": "red" - } - for k in shading: - shad[k] = shading[k] - return shad - - def __update_settings(self, settings={}): - sett = {"width": 600, "height": 600, "antialias": True, "scale": 1.5, "background": "#ffffff", - "fov": 30} - for k in settings: - sett[k] = settings[k] - self.__s = sett - - def __add_object(self, obj, parent=None): - if not parent: # Object is added to global scene and objects dict - self.__objects[self.__cnt] = obj - self.__cnt += 1 - self._scene.add(obj["mesh"]) - else: # Object is added to parent object and NOT to objects dict - parent.add(obj["mesh"]) - - self.__update_view() - - if self.render_mode == "JUPYTER": - return self.__cnt - 1 - elif self.render_mode == "WEBSITE": - return self - - def __add_line_geometry(self, lines, shading, obj=None): - lines = lines.astype("float32", copy=False) - mi = np.min(lines, axis=0) - ma = np.max(lines, axis=0) - - geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3))) - material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"]) - # , vertexColors='VertexColors'), - lines = p3s.LineSegments2(geometry=geometry, material=material) # type='LinePieces') - line_obj = {"geometry": geometry, "mesh": lines, "material": material, - "max": ma, "min": mi, "type": "Lines", "wireframe": None} - - if obj: - return self.__add_object(line_obj, obj), line_obj - else: - return self.__add_object(line_obj) - - def __update_view(self): - if len(self.__objects) == 0: - return - ma = np.zeros((len(self.__objects), 3)) - mi = np.zeros((len(self.__objects), 3)) - for r, obj in enumerate(self.__objects): - ma[r] = self.__objects[obj]["max"] - mi[r] = self.__objects[obj]["min"] - ma = np.max(ma, axis=0) - mi = np.min(mi, axis=0) - diag = np.linalg.norm(ma - mi) - mean = ((ma - mi) / 2 + mi).tolist() - scale = self.__s["scale"] * (diag) - self._orbit.target = mean - self._cam.lookAt(mean) - self._cam.position = [mean[0], mean[1], mean[2] + scale] - self._light.position = [mean[0], mean[1], mean[2] + scale] - - self._orbit.exec_three_obj_method('update') - self._cam.exec_three_obj_method('updateProjectionMatrix') - - def __get_bbox(self, v): - m = np.min(v, axis=0) - M = np.max(v, axis=0) - - # Corners of the bounding box - v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]], - [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]]) - - f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4], - [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32) - return v_box, f_box - - def __get_colors(self, v, f, c, sh): - coloring = "VertexColors" - if type(c) == np.ndarray and c.size == 3: # Single color - colors = np.ones_like(v) - colors[:, 0] = c[0] - colors[:, 1] = c[1] - colors[:, 2] = c[2] - # print("Single colors") - elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3: # Color values for - if c.shape[0] == f.shape[0]: # faces - colors = np.hstack([c, c, c]).reshape((-1, 3)) - coloring = "FaceColors" - # print("Face color values") - elif c.shape[0] == v.shape[0]: # vertices - colors = c - # print("Vertex color values") - else: # Wrong size, fallback - print("Invalid color array given! Supported are numpy arrays.", type(c)) - colors = np.ones_like(v) - colors[:, 0] = 1.0 - colors[:, 1] = 0.874 - colors[:, 2] = 0.0 - elif type(c) == np.ndarray and c.size == f.shape[0]: # Function values for faces - normalize = sh["normalize"][0] != None and sh["normalize"][1] != None - cc = get_colors(c, sh["colormap"], normalize=normalize, - vmin=sh["normalize"][0], vmax=sh["normalize"][1]) - # print(cc.shape) - colors = np.hstack([cc, cc, cc]).reshape((-1, 3)) - coloring = "FaceColors" - # print("Face function values") - elif type(c) == np.ndarray and c.size == v.shape[0]: # Function values for vertices - normalize = sh["normalize"][0] != None and sh["normalize"][1] != None - colors = get_colors(c, sh["colormap"], normalize=normalize, - vmin=sh["normalize"][0], vmax=sh["normalize"][1]) - # print("Vertex function values") - - else: - colors = np.ones_like(v) - # colors[:, 0] = 1.0 - # colors[:, 1] = 0.874 - # colors[:, 2] = 0.0 - colors[:, 0] = 1 - colors[:, 1] = 1 - colors[:, 2] = 1 - - # No color - if c is not None: - print("Invalid color array given! Supported are numpy arrays.", type(c)) - - return colors, coloring - - def __get_point_colors(self, v, c, sh): - v_color = True - if c is None: # No color given, use global color - # conv = mpl.colors.ColorConverter() - colors = sh["point_color"] # np.array(conv.to_rgb(sh["point_color"])) - v_color = False - elif isinstance(c, str): # No color given, use global color - # conv = mpl.colors.ColorConverter() - colors = c # np.array(conv.to_rgb(c)) - v_color = False - elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3: - # Point color - colors = c.astype("float32", copy=False) - - elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3: - # Function values for vertices, but the colors are features - c_norm = np.linalg.norm(c, ord=2, axis=-1) - normalize = sh["normalize"][0] != None and sh["normalize"][1] != None - colors = get_colors(c_norm, sh["colormap"], normalize=normalize, - vmin=sh["normalize"][0], vmax=sh["normalize"][1]) - colors = colors.astype("float32", copy=False) - - elif type(c) == np.ndarray and c.size == v.shape[0]: # Function color - normalize = sh["normalize"][0] != None and sh["normalize"][1] != None - colors = get_colors(c, sh["colormap"], normalize=normalize, - vmin=sh["normalize"][0], vmax=sh["normalize"][1]) - colors = colors.astype("float32", copy=False) - # print("Vertex function values") - - else: - print("Invalid color array given! Supported are numpy arrays.", type(c)) - colors = sh["point_color"] - v_color = False - - return colors, v_color - - def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs): - shading.update(kwargs) - sh = self.__get_shading(shading) - mesh_obj = {} - - # it is a tet - if v.shape[1] == 3 and f.shape[1] == 4: - f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype) - for i in range(f.shape[0]): - f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]]) - f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]]) - f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]]) - f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]]) - f = f_tmp - - if v.shape[1] == 2: - v = np.append(v, np.zeros([v.shape[0], 1]), 1) - - # Type adjustment vertices - v = v.astype("float32", copy=False) - - # Color setup - colors, coloring = self.__get_colors(v, f, c, sh) - - # Type adjustment faces and colors - c = colors.astype("float32", copy=False) - - # Material and geometry setup - ba_dict = {"color": p3s.BufferAttribute(c)} - if coloring == "FaceColors": - verts = np.zeros((f.shape[0] * 3, 3), dtype="float32") - for ii in range(f.shape[0]): - # print(ii*3, f[ii]) - verts[ii * 3] = v[f[ii, 0]] - verts[ii * 3 + 1] = v[f[ii, 1]] - verts[ii * 3 + 2] = v[f[ii, 2]] - v = verts - else: - f = f.astype("uint32", copy=False).ravel() - ba_dict["index"] = p3s.BufferAttribute(f, normalized=False) - - ba_dict["position"] = p3s.BufferAttribute(v, normalized=False) - - if uv is not None: - uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv)) - if texture_data is None: - texture_data = gen_checkers(20, 20) - tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType") - material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"], - roughness=sh["roughness"], metalness=sh["metalness"], - flatShading=sh["flat"], - polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5) - ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False)) - else: - material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"], - side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"], - flatShading=sh["flat"], - polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5) - - if type(n) != type(None) and coloring == "VertexColors": # TODO: properly handle normals for FaceColors as well - ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True) - - geometry = p3s.BufferGeometry(attributes=ba_dict) - - if coloring == "VertexColors" and type(n) == type(None): - geometry.exec_three_obj_method('computeVertexNormals') - elif coloring == "FaceColors" and type(n) == type(None): - geometry.exec_three_obj_method('computeFaceNormals') - - # Mesh setup - mesh = p3s.Mesh(geometry=geometry, material=material) - - # Wireframe setup - mesh_obj["wireframe"] = None - if sh["wireframe"]: - wf_geometry = p3s.WireframeGeometry(mesh.geometry) # WireframeGeometry - wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"]) - wireframe = p3s.LineSegments(wf_geometry, wf_material) - mesh.add(wireframe) - mesh_obj["wireframe"] = wireframe - - # Bounding box setup - if sh["bbox"]: - v_box, f_box = self.__get_bbox(v) - _, bbox = self.add_edges(v_box, f_box, sh, mesh) - mesh_obj["bbox"] = [bbox, v_box, f_box] - - # Object setup - mesh_obj["max"] = np.max(v, axis=0) - mesh_obj["min"] = np.min(v, axis=0) - mesh_obj["geometry"] = geometry - mesh_obj["mesh"] = mesh - mesh_obj["material"] = material - mesh_obj["type"] = "Mesh" - mesh_obj["shading"] = sh - mesh_obj["coloring"] = coloring - mesh_obj["arrays"] = [v, f, c] # TODO replays with proper storage or remove if not needed - - return self.__add_object(mesh_obj) - - def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs): - shading.update(kwargs) - if len(beginning.shape) == 1: - if len(beginning) == 2: - beginning = np.array([[beginning[0], beginning[1], 0]]) - else: - if beginning.shape[1] == 2: - beginning = np.append( - beginning, np.zeros([beginning.shape[0], 1]), 1) - if len(ending.shape) == 1: - if len(ending) == 2: - ending = np.array([[ending[0], ending[1], 0]]) - else: - if ending.shape[1] == 2: - ending = np.append( - ending, np.zeros([ending.shape[0], 1]), 1) - - sh = self.__get_shading(shading) - lines = np.hstack([beginning, ending]) - lines = lines.reshape((-1, 3)) - return self.__add_line_geometry(lines, sh, obj) - - def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs): - shading.update(kwargs) - if vertices.shape[1] == 2: - vertices = np.append( - vertices, np.zeros([vertices.shape[0], 1]), 1) - sh = self.__get_shading(shading) - lines = np.zeros((edges.size, 3)) - cnt = 0 - for e in edges: - lines[cnt, :] = vertices[e[0]] - lines[cnt + 1, :] = vertices[e[1]] - cnt += 2 - return self.__add_line_geometry(lines, sh, obj) - - def add_points(self, points, c=None, shading={}, obj=None, **kwargs): - shading.update(kwargs) - if len(points.shape) == 1: - if len(points) == 2: - points = np.array([[points[0], points[1], 0]]) - else: - if points.shape[1] == 2: - points = np.append( - points, np.zeros([points.shape[0], 1]), 1) - sh = self.__get_shading(shading) - points = points.astype("float32", copy=False) - mi = np.min(points, axis=0) - ma = np.max(points, axis=0) - - g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)} - m_attributes = {"size": sh["point_size"]} - - if sh["point_shape"] == "circle": # Plot circles - tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType") - m_attributes["map"] = tex - m_attributes["alphaTest"] = 0.5 - m_attributes["transparency"] = True - else: # Plot squares - pass - - colors, v_colors = self.__get_point_colors(points, c, sh) - if v_colors: # Colors per point - m_attributes["vertexColors"] = 'VertexColors' - g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False) - - else: # Colors for all points - m_attributes["color"] = colors - - material = p3s.PointsMaterial(**m_attributes) - geometry = p3s.BufferGeometry(attributes=g_attributes) - points = p3s.Points(geometry=geometry, material=material) - point_obj = {"geometry": geometry, "mesh": points, "material": material, - "max": ma, "min": mi, "type": "Points", "wireframe": None} - - if obj: - return self.__add_object(point_obj, obj), point_obj - else: - return self.__add_object(point_obj) - - def remove_object(self, obj_id): - if obj_id not in self.__objects: - print("Invalid object id. Valid ids are: ", list(self.__objects.keys())) - return - self._scene.remove(self.__objects[obj_id]["mesh"]) - del self.__objects[obj_id] - self.__update_view() - - def reset(self): - for obj_id in list(self.__objects.keys()).copy(): - self._scene.remove(self.__objects[obj_id]["mesh"]) - del self.__objects[obj_id] - self.__update_view() - - def update_object(self, oid=0, vertices=None, colors=None, faces=None): - obj = self.__objects[oid] - if type(vertices) != type(None): - if obj["coloring"] == "FaceColors": - f = obj["arrays"][1] - verts = np.zeros((f.shape[0] * 3, 3), dtype="float32") - for ii in range(f.shape[0]): - # print(ii*3, f[ii]) - verts[ii * 3] = vertices[f[ii, 0]] - verts[ii * 3 + 1] = vertices[f[ii, 1]] - verts[ii * 3 + 2] = vertices[f[ii, 2]] - v = verts - - else: - v = vertices.astype("float32", copy=False) - obj["geometry"].attributes["position"].array = v - # self.wireframe.attributes["position"].array = v # Wireframe updates? - obj["geometry"].attributes["position"].needsUpdate = True - # obj["geometry"].exec_three_obj_method('computeVertexNormals') - if type(colors) != type(None): - colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"]) - colors = colors.astype("float32", copy=False) - obj["geometry"].attributes["color"].array = colors - obj["geometry"].attributes["color"].needsUpdate = True - if type(faces) != type(None): - if obj["coloring"] == "FaceColors": - print("Face updates are currently only possible in vertex color mode.") - return - f = faces.astype("uint32", copy=False).ravel() - print(obj["geometry"].attributes) - obj["geometry"].attributes["index"].array = f - # self.wireframe.attributes["position"].array = v # Wireframe updates? - obj["geometry"].attributes["index"].needsUpdate = True - # obj["geometry"].exec_three_obj_method('computeVertexNormals') - # self.mesh.geometry.verticesNeedUpdate = True - # self.mesh.geometry.elementsNeedUpdate = True - # self.update() - if self.render_mode == "WEBSITE": - return self - - # def update(self): - # self.mesh.exec_three_obj_method('update') - # self.orbit.exec_three_obj_method('update') - # self.cam.exec_three_obj_method('updateProjectionMatrix') - # self.scene.exec_three_obj_method('update') - - def add_text(self, text, shading={}, **kwargs): - shading.update(kwargs) - sh = self.__get_shading(shading) - tt = p3s.TextTexture(string=text, color=sh["text_color"]) - sm = p3s.SpriteMaterial(map=tt) - text = p3s.Sprite(material=sm, scaleToTexture=True) - self._scene.add(text) - - # def add_widget(self, widget, callback): - # self.widgets.append(widget) - # widget.observe(callback, names='value') - - # def add_dropdown(self, options, default, desc, cb): - # widget = widgets.Dropdown(options=options, value=default, description=desc) - # self.__widgets.append(widget) - # widget.observe(cb, names="value") - # display(widget) - - # def add_button(self, text, cb): - # button = widgets.Button(description=text) - # self.__widgets.append(button) - # button.on_click(cb) - # display(button) - - def to_html(self, imports=True, html_frame=True): - # Bake positions (fixes centering bug in offline rendering) - if len(self.__objects) == 0: - return - ma = np.zeros((len(self.__objects), 3)) - mi = np.zeros((len(self.__objects), 3)) - for r, obj in enumerate(self.__objects): - ma[r] = self.__objects[obj]["max"] - mi[r] = self.__objects[obj]["min"] - ma = np.max(ma, axis=0) - mi = np.min(mi, axis=0) - diag = np.linalg.norm(ma - mi) - mean = (ma - mi) / 2 + mi - for r, obj in enumerate(self.__objects): - v = self.__objects[obj]["geometry"].attributes["position"].array - v -= mean - v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window - - scale = self.__s["scale"] * (diag) - self._orbit.target = [0.0, 0.0, 0.0] - self._cam.lookAt([0.0, 0.0, 0.0]) - # self._cam.position = [0.0, 0.0, scale] - self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window - self._light.position = [0.0, 0.0, scale] - - state = embed.dependency_state(self._renderer) - - # Somehow these entries are missing when the state is exported in python. - # Exporting from the GUI works, so we are inserting the missing entries. - for k in state: - if state[k]["model_name"] == "OrbitControlsModel": - state[k]["state"]["maxAzimuthAngle"] = "inf" - state[k]["state"]["maxDistance"] = "inf" - state[k]["state"]["maxZoom"] = "inf" - state[k]["state"]["minAzimuthAngle"] = "-inf" - - tpl = embed.load_requirejs_template - if not imports: - embed.load_requirejs_template = "" - - s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL) - # s = embed.embed_snippet(self.__w, state=state) - embed.load_requirejs_template = tpl - - if html_frame: - s = "\n\n" + s + "\n\n" - - # Revert changes - for r, obj in enumerate(self.__objects): - v = self.__objects[obj]["geometry"].attributes["position"].array - v += mean - self.__update_view() - - return s - - def save(self, filename=""): - if filename == "": - uid = str(uuid.uuid4()) + ".html" - else: - filename = filename.replace(".html", "") - uid = filename + '.html' - with open(uid, "w") as f: - f.write(self.to_html()) - print("Plot saved to file %s." % uid) diff --git a/examples/bug.png b/examples/bug.png new file mode 100644 index 0000000000000000000000000000000000000000..bd730216d50e830c4010f08a30442f626268a3ac Binary files /dev/null and b/examples/bug.png differ diff --git a/examples/dragon.png b/examples/dragon.png new file mode 100644 index 0000000000000000000000000000000000000000..1ed16896722535aedbdaf629dd0068cd168bd597 Binary files /dev/null and b/examples/dragon.png differ diff --git a/examples/mask.png b/examples/mask.png new file mode 100644 index 0000000000000000000000000000000000000000..6ab444c7a2850ac645567ff6f93ed8149d9cbb45 Binary files /dev/null and b/examples/mask.png differ diff --git a/examples/monster.png b/examples/monster.png new file mode 100644 index 0000000000000000000000000000000000000000..1cfe5eb404319480d5f17390fd4c76baa890b300 Binary files /dev/null and b/examples/monster.png differ diff --git a/examples/werewolf.png b/examples/werewolf.png new file mode 100644 index 0000000000000000000000000000000000000000..f20d321902a1c50f642533dbe90ee1aa83039ccd Binary files /dev/null and b/examples/werewolf.png differ diff --git a/gradio_app.py b/gradio_app.py old mode 100644 new mode 100755 index c5944f7bae5525247a2b26d9934fc4c31912e47b..3e2e0140e0d81aac04fd03a5c5af373a1fd9bf39 --- a/gradio_app.py +++ b/gradio_app.py @@ -9,26 +9,32 @@ import importlib import numpy as np from omegaconf import OmegaConf from huggingface_hub import hf_hub_download +from diffusers import DiffusionPipeline +import PIL +from PIL import Image from collections import OrderedDict import trimesh +import rembg import gradio as gr from typing import Any -from einops import rearrange proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.join(proj_dir)) import tempfile -from apps.utils import * +import craftsman +from craftsman.utils.config import ExperimentConfig, load_config _TITLE = '''CraftsMan: High-fidelity Mesh Generation with 3D Native Generation and Interactive Geometry Refiner''' _DESCRIPTION = '''
-Important: The ckpt models released have been primarily trained on character data, hence they are likely to exhibit superior performance in this category. We are also planning to release more advanced pretrained models in the future. +Important: If you have your own data and want to collaborate, we are welcom to any contact. +
+Select or upload a image, then just click 'Generate'.
-By mimicking the artist/craftsman modeling workflow, we propose CraftsMan (aka 匠心) which uses 3D Latent Set Diffusion Model that directly generates coarse meshes, +By mimicking the artist/craftsman modeling workflow, we propose CraftsMan (aka 匠心) that uses 3D Latent Set Diffusion Model that directly generate coarse meshes, then a multi-view normal enhanced image generation model is used to refine the mesh. We provide the coarse 3D diffusion part here.
@@ -36,8 +42,6 @@ If you found CraftsMan is helpful, please help to ⭐ the
*If you have your own multi-view images, you can directly upload it. -Tutorial -使用教程
''' _CITE_ = r""" @@ -59,153 +63,101 @@ CraftsMan is under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html), so 📧 **Contact** If you have any questions, feel free to open a discussion or contact us at weiyuli.cn@gmail.com. """ -from apps.third_party.CRM.pipelines import TwoStagePipeline -from apps.third_party.LGM.pipeline_mvdream import MVDreamPipeline -from apps.third_party.Era3D.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline -from apps.third_party.Era3D.data.single_image_dataset import SingleImageDataset - -import re -import os -import stat - -RD, WD, XD = 4, 2, 1 -BNS = [RD, WD, XD] -MDS = [ - [stat.S_IRUSR, stat.S_IRGRP, stat.S_IROTH], - [stat.S_IWUSR, stat.S_IWGRP, stat.S_IWOTH], - [stat.S_IXUSR, stat.S_IXGRP, stat.S_IXOTH] -] - -def chmod(path, mode): - if isinstance(mode, int): - mode = str(mode) - if not re.match("^[0-7]{1,3}$", mode): - raise Exception("mode does not conform to ^[0-7]{1,3}$ pattern") - mode = "{0:0>3}".format(mode) - mode_num = 0 - for midx, m in enumerate(mode): - for bnidx, bn in enumerate(BNS): - if (int(m) & bn) > 0: - mode_num += MDS[bnidx][midx] - os.chmod(path, mode_num) - -chmod(f"{parent_dir}/apps/third_party/InstantMeshes", "777") -device = None model = None cached_dir = None -generator = None - -sys.path.append(f"apps/third_party/CRM") -crm_pipeline = None - -sys.path.append(f"apps/third_party/LGM") -imgaedream_pipeline = None - -sys.path.append(f"apps/third_party/Era3D") -era3d_pipeline = None - -@spaces.GPU(duration=120) -def gen_mvimg( - mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation, backgroud_color -): - global device - if seed == 0: - seed = np.random.randint(1, 65535) - global generator - generator = torch.Generator(device) - generator.manual_seed(seed) - if mvimg_model == "CRM": - global crm_pipeline - crm_pipeline.set_seed(seed) - background = Image.new("RGBA", image.size, (127, 127, 127)) - image = Image.alpha_composite(background, image) - mv_imgs = crm_pipeline( - image, - scale=guidance_scale, - step=step - )["stage1_images"] - return mv_imgs[5], mv_imgs[3], mv_imgs[2], mv_imgs[0] - - elif mvimg_model == "ImageDream": - global imagedream_pipeline - background = Image.new("RGBA", image.size, backgroud_color) - image = Image.alpha_composite(background, image) - image = np.array(image).astype(np.float32) / 255.0 - image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) - mv_imgs = imagedream_pipeline( - text, - image, - negative_prompt=neg_text, - guidance_scale=guidance_scale, - num_inference_steps=step, - elevation=elevation, - generator=generator, - ) - return mv_imgs[1], mv_imgs[2], mv_imgs[3], mv_imgs[0] - - elif mvimg_model == "Era3D": - global era3d_pipeline - era3d_pipeline.to(device) - era3d_pipeline.unet.enable_xformers_memory_efficient_attention() - era3d_pipeline.set_progress_bar_config(disable=True) +generator = None - crop_size = 420 - batch = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white', - crop_size=crop_size, single_image=image, prompt_embeds_path='apps/third_party/Era3D/data/fixed_prompt_embeds_6view')[0] - imgs_in = torch.cat([batch['imgs_in']]*2, dim=0) - imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W) - - normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings'] - prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0) - prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C") - - imgs_in = imgs_in.to(dtype=torch.float16) - prompt_embeddings = prompt_embeddings.to(dtype=torch.float16) +def check_input_image(input_image): + if input_image is None: + raise gr.Error("No image uploaded!") + +class RMBG(object): + def __init__(self): + pass + + def rmbg_rembg(self, input_image, background_color): + def _rembg_remove( + image: PIL.Image.Image, + rembg_session = None, + force: bool = False, + **rembg_kwargs, + ) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + # explain why current do not rm bg + print("alhpa channl not enpty, skip remove background, using alpha channel as mask") + background = Image.new("RGBA", image.size, background_color) + image = Image.alpha_composite(background, image) + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + + # calculate the min bbox of the image + alpha = image.split()[-1] + image = image.crop(alpha.getbbox()) + + return image + return _rembg_remove(input_image, None, force_remove=True) + + def run(self, rm_type, image, foreground_ratio, background_choice, background_color=(0, 0, 0, 0)): + if "Original" in background_choice: + return image + else: + if background_choice == "Alpha as mask": + alpha = image.split()[-1] + image = image.crop(alpha.getbbox()) + + elif "Remove" in background_choice: + if rm_type.upper() == "REMBG": + image = self.rmbg_rembg(image, background_color=background_color) + else: + return -1 - mv_imgs = era3d_pipeline( - imgs_in, - None, - prompt_embeds=prompt_embeddings, - generator=generator, - guidance_scale=guidance_scale, - num_inference_steps=step, - num_images_per_prompt=1, - **{'eta': 1.0} - ).images - return mv_imgs[6], mv_imgs[8], mv_imgs[9], mv_imgs[10] - -@spaces.GPU -def image2mesh(view_front: np.ndarray, - view_right: np.ndarray, - view_back: np.ndarray, - view_left: np.ndarray, + # Calculate the new size after rescaling + new_size = tuple(int(dim * foreground_ratio) for dim in image.size) + # Resize the image while maintaining the aspect ratio + resized_image = image.resize(new_size) + # Create a new image with the original size and white background + padded_image = PIL.Image.new("RGBA", image.size, (0, 0, 0, 0)) + paste_position = ((image.width - resized_image.width) // 2, (image.height - resized_image.height) // 2) + padded_image.paste(resized_image, paste_position) + + # expand image to 1:1 + width, height = padded_image.size + if width == height: + return padded_image + new_size = (max(width, height), max(width, height)) + image = PIL.Image.new("RGBA", new_size, (0, 0, 0, 0)) + paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) + image.paste(padded_image, paste_position) + return image + +# @spaces.GPU +def image2mesh(image: Any, more: bool = False, scheluder_name: str ="DDIMScheduler", guidance_scale: int = 7.5, - steps: int = 50, + steps: int = 30, seed: int = 4, + target_face_count: int = 2000, octree_depth: int = 7): sample_inputs = { - "mvimages": [[ - Image.fromarray(view_front), - Image.fromarray(view_right), - Image.fromarray(view_back), - Image.fromarray(view_left) - ]] + "image": [ + image + ] } global model latents = model.sample( sample_inputs, sample_times=1, - guidance_scale=guidance_scale, - return_intermediates=False, steps=steps, + guidance_scale=guidance_scale, seed=seed - )[0] # decode the latents to mesh @@ -224,75 +176,42 @@ def image2mesh(view_front: np.ndarray, if 'Remesh' in more: remeshed_filepath = tempfile.NamedTemporaryFile(suffix=f"_remeshed.obj", delete=False).name print("Remeshing with Instant Meshes...") - # target_face_count = int(len(mesh.faces)/10) - target_face_count = 2000 command = f"{proj_dir}/apps/third_party/InstantMeshes {filepath} -f {target_face_count} -o {remeshed_filepath}" os.system(command) - del filepath filepath = remeshed_filepath - # filepath = filepath.replace('.obj', '_remeshed.obj') return filepath if __name__=="__main__": parser = argparse.ArgumentParser() - # parser.add_argument("--model_path", type=str, required=True, help="Path to the object file",) - parser.add_argument("--cached_dir", type=str, default="./gradio_cached_dir") + parser.add_argument("--model_path", type=str, default="./ckpts/craftsman-v1-5", help="Path to the object file",) + parser.add_argument("--cached_dir", type=str, default="") parser.add_argument("--device", type=int, default=0) args = parser.parse_args() + cached_dir = args.cached_dir - os.makedirs(args.cached_dir, exist_ok=True) + if cached_dir != "": + os.makedirs(args.cached_dir, exist_ok=True) device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu") print(f"using device: {device}") - # for multi-view images generation + # for input image background_choice = OrderedDict({ "Alpha as Mask": "Alpha as Mask", "Auto Remove Background": "Auto Remove Background", "Original Image": "Original Image", }) - mvimg_model_config_list = [ - "Era3D", - "CRM", - "ImageDream" - ] - if "Era3D" in mvimg_model_config_list: - # cfg = load_config("apps/third_party/Era3D/configs/test_unclip-512-6view.yaml") - # schema = OmegaConf.structured(TestConfig) - # cfg = OmegaConf.merge(schema, cfg) - era3d_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained( - 'pengHTYX/MacLab-Era3D-512-6view', - dtype=torch.float16, - ) - # enable xformers - # era3d_pipeline.unet.enable_xformers_memory_efficient_attention() - # era3d_pipeline.to(device) - if "CRM" in mvimg_model_config_list: - stage1_config = OmegaConf.load(f"apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config - stage1_sampler_config = stage1_config.sampler - stage1_model_config = stage1_config.models - stage1_model_config.resume = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth", repo_type="model") - stage1_model_config.config = f"apps/third_party/CRM/" + stage1_model_config.config - crm_pipeline = TwoStagePipeline( - stage1_model_config, - stage1_sampler_config, - device=device, - dtype=torch.float16 - ) - if "ImageDream" in mvimg_model_config_list: - imagedream_pipeline = MVDreamPipeline.from_pretrained( - "ashawkey/imagedream-ipmv-diffusers", # remote weights - torch_dtype=torch.float16, - trust_remote_code=True, - ) + generator = torch.Generator(device) # for 3D latent set diffusion - ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-aligned-vae/model.ckpt", repo_type="model") - config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-aligned-vae/config.yaml", repo_type="model") - # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model-300k.ckpt", repo_type="model") - # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model") + if args.model_path == "": + ckpt_path = hf_hub_download(repo_id="craftsman3d/craftsman-v1-5", filename="model.ckpt", repo_type="model") + config_path = hf_hub_download(repo_id="craftsman3d/craftsman-v1-5", filename="config.yaml", repo_type="model") + else: + ckpt_path = os.path.join(args.model_path, "model.ckpt") + config_path = os.path.join(args.model_path, "config.yaml") scheluder_dict = OrderedDict({ "DDIMScheduler": 'diffusers.schedulers.DDIMScheduler', # "DPMSolverMultistepScheduler": 'diffusers.schedulers.DPMSolverMultistepScheduler', # not support yet @@ -330,22 +249,12 @@ if __name__=="__main__": gr.Markdown('''Try a different seed and MV Model for better results. Good Luck :)''') with gr.Row(): seed = gr.Number(0, label='Seed', show_label=True) - mvimg_model = gr.Dropdown(value="CRM", label="MV Image Model", choices=list(mvimg_model_config_list)) more = gr.CheckboxGroup(["Remesh"], label="More", show_label=False) - - with gr.Row(): - # input prompt - text = gr.Textbox(label="Prompt (Opt.)", info="only works for ImageDream") - - with gr.Accordion('Advanced options', open=False): - # negative prompt - neg_text = gr.Textbox(label="Negative Prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate') - # elevation - elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0) + target_face_count = gr.Number(2000, label='Target Face Count', show_label=True) with gr.Row(): gr.Examples( - examples=[os.path.join("./apps/examples", i) for i in os.listdir("./apps/examples")], + examples=[os.path.join("./assets/examples", i) for i in os.listdir("./assets/examples")], inputs=[image_input], examples_per_page=8 ) @@ -357,43 +266,17 @@ if __name__=="__main__": camera_position=(90.0, 90.0, 3.5), interactive=False, ) - # with gr.Row(): - # gr.Markdown('''*please note that the model is fliped due to the gradio viewer, please download the obj file and you will get the correct orientation.''') - with gr.Row(): - view_front = gr.Image(label="Front", interactive=True, show_label=True) - view_right = gr.Image(label="Right", interactive=True, show_label=True) - view_back = gr.Image(label="Back", interactive=True, show_label=True) - view_left = gr.Image(label="Left", interactive=True, show_label=True) + gr.Markdown('''*please note that the model is fliped due to the gradio viewer, please download the obj file and you will get the correct orientation.''') with gr.Accordion('Advanced options', open=False): - with gr.Row(equal_height=True): - run_mv_btn = gr.Button('Only Generate 2D', interactive=True) - run_3d_btn = gr.Button('Only Generate 3D', interactive=True) - - with gr.Accordion('Advanced options (2D)', open=False): - with gr.Row(): - foreground_ratio = gr.Slider( - label="Foreground Ratio", - minimum=0.5, - maximum=1.0, - value=1.0, - step=0.05, - ) - with gr.Row(): background_choice = gr.Dropdown(label="Backgroud Choice", value="Auto Remove Background",choices=list(background_choice.keys())) rmbg_type = gr.Dropdown(label="Backgroud Remove Type", value="rembg",choices=['sam', "rembg"]) - backgroud_color = gr.ColorPicker(label="Background Color", value="#FFFFFF", interactive=True) - # backgroud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=True) - - with gr.Row(): - mvimg_guidance_scale = gr.Number(value=3.0, minimum=1, maximum=10, label="2D Guidance Scale") - mvimg_steps = gr.Number(value=30, minimum=20, maximum=100, label="2D Sample Steps") - - with gr.Accordion('Advanced options (3D)', open=False): + foreground_ratio = gr.Slider(label="Foreground Ratio", value=1.0, minimum=0.5, maximum=1.0, step=0.01) + with gr.Row(): - guidance_scale = gr.Number(label="3D Guidance Scale", value=3.0, minimum=1.0, maximum=10.0) + guidance_scale = gr.Number(label="3D Guidance Scale", value=5.0, minimum=3.0, maximum=10.0) steps = gr.Number(value=50, minimum=20, maximum=100, label="3D Sample Steps") with gr.Row(): @@ -403,31 +286,28 @@ if __name__=="__main__": gr.Markdown(_CITE_) outputs = [output_model_obj] - rmbg = RMBG(device) + rmbg = RMBG() - model = load_model(ckpt_path, config_path, device) + # model = load_model(ckpt_path, config_path, device) + cfg = load_config(config_path) + model = craftsman.find(cfg.system_type)(cfg.system) + print(f"Restoring states from the checkpoint path at {ckpt_path} with config {cfg}") + ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) + model.load_state_dict( + ckpt["state_dict"] if "state_dict" in ckpt else ckpt, + ) + model = model.to(device).eval() + run_btn.click(fn=check_input_image, inputs=[image_input] ).success( fn=rmbg.run, - inputs=[rmbg_type, image_input, foreground_ratio, background_choice, backgroud_color], + inputs=[rmbg_type, image_input, foreground_ratio, background_choice], outputs=[image_input] - ).success( - fn=gen_mvimg, - inputs=[mvimg_model, image_input, seed, mvimg_guidance_scale, mvimg_steps, text, neg_text, elevation, backgroud_color], - outputs=[view_front, view_right, view_back, view_left] ).success( fn=image2mesh, - inputs=[view_front, view_right, view_back, view_left, more, scheduler, guidance_scale, steps, seed, octree_depth], + inputs=[image_input, more, scheduler, guidance_scale, steps, seed, target_face_count, octree_depth], outputs=outputs, api_name="generate_img2obj") - run_mv_btn.click(fn=gen_mvimg, - inputs=[mvimg_model, image_input, seed, mvimg_guidance_scale, mvimg_steps, text, neg_text, elevation, backgroud_color], - outputs=[view_front, view_right, view_back, view_left] - ) - run_3d_btn.click(fn=image2mesh, - inputs=[view_front, view_right, view_back, view_left, more, scheduler, guidance_scale, steps, seed, octree_depth], - outputs=outputs, - api_name="generate_img2obj") demo.queue().launch(share=True, allowed_paths=[args.cached_dir]) \ No newline at end of file diff --git a/launch.py b/launch.py deleted file mode 100644 index 854425983580adb24a56bc56071a9003e087f47d..0000000000000000000000000000000000000000 --- a/launch.py +++ /dev/null @@ -1,299 +0,0 @@ -import argparse -import contextlib -import importlib -import logging -import os -import sys -import time -import traceback -import pytorch_lightning as pl -import torch -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger -from pytorch_lightning.utilities.rank_zero import rank_zero_only -import craftsman -from craftsman.systems.base import BaseSystem -from craftsman.utils.callbacks import ( - CodeSnapshotCallback, - ConfigSnapshotCallback, - CustomProgressBar, - ProgressCallback, -) -from craftsman.utils.config import ExperimentConfig, load_config -from craftsman.utils.misc import get_rank -from craftsman.utils.typing import Optional -class ColoredFilter(logging.Filter): - """ - A logging filter to add color to certain log levels. - """ - - RESET = "\033[0m" - RED = "\033[31m" - GREEN = "\033[32m" - YELLOW = "\033[33m" - BLUE = "\033[34m" - MAGENTA = "\033[35m" - CYAN = "\033[36m" - - COLORS = { - "WARNING": YELLOW, - "INFO": GREEN, - "DEBUG": BLUE, - "CRITICAL": MAGENTA, - "ERROR": RED, - } - - RESET = "\x1b[0m" - - def __init__(self): - super().__init__() - - def filter(self, record): - if record.levelname in self.COLORS: - color_start = self.COLORS[record.levelname] - record.levelname = f"{color_start}[{record.levelname}]" - record.msg = f"{record.msg}{self.RESET}" - return True - - -def load_custom_module(module_path): - module_name = os.path.basename(module_path) - if os.path.isfile(module_path): - sp = os.path.splitext(module_path) - module_name = sp[0] - try: - if os.path.isfile(module_path): - module_spec = importlib.util.spec_from_file_location( - module_name, module_path - ) - else: - module_spec = importlib.util.spec_from_file_location( - module_name, os.path.join(module_path, "__init__.py") - ) - - module = importlib.util.module_from_spec(module_spec) - sys.modules[module_name] = module - module_spec.loader.exec_module(module) - return True - except Exception as e: - print(traceback.format_exc()) - print(f"Cannot import {module_path} module for custom nodes:", e) - return False - - -def load_custom_modules(): - node_paths = ["custom"] - node_import_times = [] - if not os.path.exists("node_paths"): - return - for custom_node_path in node_paths: - possible_modules = os.listdir(custom_node_path) - if "__pycache__" in possible_modules: - possible_modules.remove("__pycache__") - - for possible_module in possible_modules: - module_path = os.path.join(custom_node_path, possible_module) - if ( - os.path.isfile(module_path) - and os.path.splitext(module_path)[1] != ".py" - ): - continue - if module_path.endswith(".disabled"): - continue - time_before = time.perf_counter() - success = load_custom_module(module_path) - node_import_times.append( - (time.perf_counter() - time_before, module_path, success) - ) - - if len(node_import_times) > 0: - print("\nImport times for custom modules:") - for n in sorted(node_import_times): - if n[2]: - import_message = "" - else: - import_message = " (IMPORT FAILED)" - print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) - print() - - -def main(args, extras) -> None: - # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None) - env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else [] - selected_gpus = [0] - torch.set_float32_matmul_precision("high") - - # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified. - # As far as Pytorch Lightning is concerned, we always use all available GPUs - # (possibly filtered by CUDA_VISIBLE_DEVICES). - devices = -1 - if len(env_gpus) > 0: - n_gpus = len(env_gpus) - else: - selected_gpus = list(args.gpu.split(",")) - n_gpus = len(selected_gpus) - print(f"Using {n_gpus} GPUs: {selected_gpus}") - os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu - - if args.typecheck: - from jaxtyping import install_import_hook - - install_import_hook("craftsman", "typeguard.typechecked") - - logger = logging.getLogger("pytorch_lightning") - if args.verbose: - logger.setLevel(logging.DEBUG) - - for handler in logger.handlers: - if handler.stream == sys.stderr: # type: ignore - if not args.gradio: - handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) - handler.addFilter(ColoredFilter()) - else: - handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) - - load_custom_modules() - - # parse YAML config to OmegaConf - cfg: ExperimentConfig - cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus) - - # set a different seed for each device - pl.seed_everything(cfg.seed + get_rank(), workers=True) - - dm = craftsman.find(cfg.data_type)(cfg.data) - system: BaseSystem = craftsman.find(cfg.system_type)( - cfg.system, resumed=cfg.resume is not None - ) - system.set_save_dir(os.path.join(cfg.trial_dir, "save")) - - if args.gradio: - fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs")) - fh.setLevel(logging.INFO) - if args.verbose: - fh.setLevel(logging.DEBUG) - fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) - logger.addHandler(fh) - - callbacks = [] - if args.train: - callbacks += [ - ModelCheckpoint( - dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint - ), - LearningRateMonitor(logging_interval="step"), - CodeSnapshotCallback( - os.path.join(cfg.trial_dir, "code"), use_version=False - ), - ConfigSnapshotCallback( - args.config, - cfg, - os.path.join(cfg.trial_dir, "configs"), - use_version=False, - ), - ] - if args.gradio: - callbacks += [ - ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress")) - ] - else: - callbacks += [CustomProgressBar(refresh_rate=1)] - - def write_to_text(file, lines): - with open(file, "w") as f: - for line in lines: - f.write(line + "\n") - - loggers = [] - if args.train: - # make tensorboard logging dir to suppress warning - rank_zero_only( - lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True) - )() - loggers += [ - TensorBoardLogger(cfg.trial_dir, name="tb_logs"), - CSVLogger(cfg.trial_dir, name="csv_logs"), - ] + system.get_loggers() - rank_zero_only( - lambda: write_to_text( - os.path.join(cfg.trial_dir, "cmd.txt"), - ["python " + " ".join(sys.argv), str(args)], - ) - )() - - trainer = Trainer( - callbacks=callbacks, - logger=loggers, - inference_mode=False, - accelerator="gpu", - devices=devices, - # profiler="advanced", - **cfg.trainer, - ) - - def set_system_status(system: BaseSystem, ckpt_path: Optional[str]): - if ckpt_path is None: - return - ckpt = torch.load(ckpt_path, map_location="cpu") - system.set_resume_status(ckpt["epoch"], ckpt["global_step"]) - if args.train: - trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume) - trainer.test(system, datamodule=dm) - if args.gradio: - # also export assets if in gradio mode - trainer.predict(system, datamodule=dm) - elif args.validate: - # manually set epoch and global_step as they cannot be automatically resumed - set_system_status(system, cfg.resume) - trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume) - elif args.test: - # manually set epoch and global_step as they cannot be automatically resumed - set_system_status(system, cfg.resume) - trainer.test(system, datamodule=dm, ckpt_path=cfg.resume) - elif args.export: - set_system_status(system, cfg.resume) - trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config", required=True, help="path to config file") - parser.add_argument( - "--gpu", - default="0", - help="GPU(s) to be used. 0 means use the 1st available GPU. " - "1,2 means use the 2nd and 3rd available GPU. " - "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, " - "this argument is ignored and all available GPUs are always used.", - ) - - group = parser.add_mutually_exclusive_group(required=True) - group.add_argument("--train", action="store_true") - group.add_argument("--validate", action="store_true") - group.add_argument("--test", action="store_true") - group.add_argument("--export", action="store_true") - - parser.add_argument( - "--gradio", action="store_true", help="if true, run in gradio mode" - ) - - parser.add_argument( - "--verbose", action="store_true", help="if true, set logging level to DEBUG" - ) - - parser.add_argument( - "--typecheck", - action="store_true", - help="whether to enable dynamic type checking", - ) - - args, extras = parser.parse_known_args() - - if args.gradio: - with contextlib.redirect_stdout(sys.stderr): - main(args, extras) - else: - main(args, extras) diff --git a/packages.txt b/packages.txt old mode 100644 new mode 100755 diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755