diff --git a/apps/examples/1.jpg b/apps/examples/1.jpg
deleted file mode 100644
index b4b30221c1008797a291ded245f4153bf927fbb7..0000000000000000000000000000000000000000
Binary files a/apps/examples/1.jpg and /dev/null differ
diff --git a/apps/examples/1_cute_girl.webp b/apps/examples/1_cute_girl.webp
deleted file mode 100644
index 648206b5aa91124707cd5513bcc6ed8e866e1de6..0000000000000000000000000000000000000000
Binary files a/apps/examples/1_cute_girl.webp and /dev/null differ
diff --git a/apps/examples/bird.jpg b/apps/examples/bird.jpg
deleted file mode 100644
index 62622f31d4fe078c93c546941b9b64701889f37e..0000000000000000000000000000000000000000
Binary files a/apps/examples/bird.jpg and /dev/null differ
diff --git a/apps/examples/blue_monster.webp b/apps/examples/blue_monster.webp
deleted file mode 100644
index 7204f5fa63d86843269210c306488444e0efafe1..0000000000000000000000000000000000000000
Binary files a/apps/examples/blue_monster.webp and /dev/null differ
diff --git a/apps/examples/boy2.webp b/apps/examples/boy2.webp
deleted file mode 100644
index a2e1615efd6ca3f87aac3f43884666b85a804dc4..0000000000000000000000000000000000000000
Binary files a/apps/examples/boy2.webp and /dev/null differ
diff --git a/apps/examples/bulldog.webp b/apps/examples/bulldog.webp
deleted file mode 100644
index 318fcf22cf756b31e25364f8fe28a3852c72fc9c..0000000000000000000000000000000000000000
Binary files a/apps/examples/bulldog.webp and /dev/null differ
diff --git a/apps/examples/catman.webp b/apps/examples/catman.webp
deleted file mode 100644
index 5dd68d7308d4785f314e976561facc43e9d9e53a..0000000000000000000000000000000000000000
Binary files a/apps/examples/catman.webp and /dev/null differ
diff --git a/apps/examples/cyberpunk_man.webp b/apps/examples/cyberpunk_man.webp
deleted file mode 100644
index 4c609f49d6dcfcb93d11c85b675fe1b0641300c9..0000000000000000000000000000000000000000
Binary files a/apps/examples/cyberpunk_man.webp and /dev/null differ
diff --git a/apps/examples/dinosaur_boy.webp b/apps/examples/dinosaur_boy.webp
deleted file mode 100644
index afd68e6205bba8dc5dd178ec8dcd0a6fd24287bb..0000000000000000000000000000000000000000
Binary files a/apps/examples/dinosaur_boy.webp and /dev/null differ
diff --git a/apps/examples/doraemon.webp b/apps/examples/doraemon.webp
deleted file mode 100644
index b5923198c1a6d3da835f9d9f6299ad8f3ad81732..0000000000000000000000000000000000000000
Binary files a/apps/examples/doraemon.webp and /dev/null differ
diff --git a/apps/examples/dragon.webp b/apps/examples/dragon.webp
deleted file mode 100644
index 2debdc5f7c546e8830ee6f79b7470edc43aff33a..0000000000000000000000000000000000000000
Binary files a/apps/examples/dragon.webp and /dev/null differ
diff --git a/apps/examples/dragontoy.jpg b/apps/examples/dragontoy.jpg
deleted file mode 100644
index 39ce417c23734e91a667eeb852fca63e1f6249b3..0000000000000000000000000000000000000000
Binary files a/apps/examples/dragontoy.jpg and /dev/null differ
diff --git a/apps/examples/girl1.webp b/apps/examples/girl1.webp
deleted file mode 100644
index 3e42bbdffd07d3a5449399e4c2e229d3cd1da7ec..0000000000000000000000000000000000000000
Binary files a/apps/examples/girl1.webp and /dev/null differ
diff --git a/apps/examples/gun.webp b/apps/examples/gun.webp
deleted file mode 100644
index 589c5391e53bb89c98673d5b2b2e0d76677bcb32..0000000000000000000000000000000000000000
Binary files a/apps/examples/gun.webp and /dev/null differ
diff --git a/apps/examples/kunkun.webp b/apps/examples/kunkun.webp
deleted file mode 100644
index a0c45d3c7091cdad8ef69374a9cbd065353e7982..0000000000000000000000000000000000000000
Binary files a/apps/examples/kunkun.webp and /dev/null differ
diff --git a/apps/examples/link.webp b/apps/examples/link.webp
deleted file mode 100644
index 1b5bcfc99fd3e2e6d674b5c877f4128109355900..0000000000000000000000000000000000000000
Binary files a/apps/examples/link.webp and /dev/null differ
diff --git a/apps/examples/mushroom1.webp b/apps/examples/mushroom1.webp
deleted file mode 100644
index bc41abe7d5113bb147a85fdb4f5b8f2d8e1e3851..0000000000000000000000000000000000000000
Binary files a/apps/examples/mushroom1.webp and /dev/null differ
diff --git a/apps/examples/mushroom2.webp b/apps/examples/mushroom2.webp
deleted file mode 100644
index 08d870a503df6970534797730f7e429e57a84ab8..0000000000000000000000000000000000000000
Binary files a/apps/examples/mushroom2.webp and /dev/null differ
diff --git a/apps/examples/phoenix.webp b/apps/examples/phoenix.webp
deleted file mode 100644
index 7a15695cbdbda761ac08926c40cb9d19400051d8..0000000000000000000000000000000000000000
Binary files a/apps/examples/phoenix.webp and /dev/null differ
diff --git a/apps/examples/robot.png b/apps/examples/robot.png
deleted file mode 100644
index 5522be23a8373e96af1f6ba0b08fe31ec0467a41..0000000000000000000000000000000000000000
Binary files a/apps/examples/robot.png and /dev/null differ
diff --git a/apps/examples/rose.webp b/apps/examples/rose.webp
deleted file mode 100644
index d245944b32d29a5a7c56bf72b595fcc2b7de6650..0000000000000000000000000000000000000000
Binary files a/apps/examples/rose.webp and /dev/null differ
diff --git a/apps/examples/shoe.webp b/apps/examples/shoe.webp
deleted file mode 100644
index 41bbb431f566663f3fb8ebd5dc32533cd4a2b990..0000000000000000000000000000000000000000
Binary files a/apps/examples/shoe.webp and /dev/null differ
diff --git a/apps/examples/sports_girl.webp b/apps/examples/sports_girl.webp
deleted file mode 100644
index 15006e2e0608f602aa79d17ea0f24d94b9e62061..0000000000000000000000000000000000000000
Binary files a/apps/examples/sports_girl.webp and /dev/null differ
diff --git a/apps/examples/stone.webp b/apps/examples/stone.webp
deleted file mode 100644
index c462867d2f5d5a0f9ab7414b20af8b6adb46e2ae..0000000000000000000000000000000000000000
Binary files a/apps/examples/stone.webp and /dev/null differ
diff --git a/apps/examples/sweater.webp b/apps/examples/sweater.webp
deleted file mode 100644
index 31bcb40ca77d0b2a6cb37ebde421c1d72fd44ddd..0000000000000000000000000000000000000000
Binary files a/apps/examples/sweater.webp and /dev/null differ
diff --git a/apps/examples/sword.webp b/apps/examples/sword.webp
deleted file mode 100644
index be576fa4a76ca8ddcc934d81caaa65eac2faa04c..0000000000000000000000000000000000000000
Binary files a/apps/examples/sword.webp and /dev/null differ
diff --git a/apps/examples/teapot.webp b/apps/examples/teapot.webp
deleted file mode 100644
index a1e5809eeae47e05149c9f7b19d80810de58d19f..0000000000000000000000000000000000000000
Binary files a/apps/examples/teapot.webp and /dev/null differ
diff --git a/apps/examples/toy_bear.webp b/apps/examples/toy_bear.webp
deleted file mode 100644
index 7cfff09ab76ea5f0ea6ac6a180d349e593e05fee..0000000000000000000000000000000000000000
Binary files a/apps/examples/toy_bear.webp and /dev/null differ
diff --git a/apps/examples/toy_dog.webp b/apps/examples/toy_dog.webp
deleted file mode 100644
index 9e9d6aae6fa32807b734df2351a3d07e7e1ffa55..0000000000000000000000000000000000000000
Binary files a/apps/examples/toy_dog.webp and /dev/null differ
diff --git a/apps/examples/toy_pig.webp b/apps/examples/toy_pig.webp
deleted file mode 100644
index 600edfaea0c067f78d9db4c1e9af044b92437c91..0000000000000000000000000000000000000000
Binary files a/apps/examples/toy_pig.webp and /dev/null differ
diff --git a/apps/examples/toy_rabbit.webp b/apps/examples/toy_rabbit.webp
deleted file mode 100644
index 0c51b6e83ba928031967ae091c31b40efd9580ae..0000000000000000000000000000000000000000
Binary files a/apps/examples/toy_rabbit.webp and /dev/null differ
diff --git a/apps/examples/wiking.webp b/apps/examples/wiking.webp
deleted file mode 100644
index 71ff562749bbe8a43c154d282f1c21ca21dd5a96..0000000000000000000000000000000000000000
Binary files a/apps/examples/wiking.webp and /dev/null differ
diff --git a/apps/examples/wings.webp b/apps/examples/wings.webp
deleted file mode 100644
index 41b9e0e629d029bd27674b420d6a9fe0b3655156..0000000000000000000000000000000000000000
Binary files a/apps/examples/wings.webp and /dev/null differ
diff --git a/apps/third_party/CRM/LICENSE b/apps/third_party/CRM/LICENSE
deleted file mode 100644
index 8840910d3e0d809884ce88440d615cf493475272..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2024 TSAIL group
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/apps/third_party/CRM/README.md b/apps/third_party/CRM/README.md
deleted file mode 100644
index 0fb9821a04f51330e00206575bab88fce08fb786..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/README.md
+++ /dev/null
@@ -1,85 +0,0 @@
-# Convolutional Reconstruction Model
-
-Official implementation for *CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model*.
-
-**CRM is a feed-forward model that can generate a 3D textured mesh in 10 seconds.**
-
-## [Project Page](https://ml.cs.tsinghua.edu.cn/~zhengyi/CRM/) | [Arxiv](https://arxiv.org/abs/2403.05034) | [HF-Demo](https://huggingface.co./spaces/Zhengyi/CRM) | [Weights](https://huggingface.co./Zhengyi/CRM)
-
-https://github.com/thu-ml/CRM/assets/40787266/8b325bc0-aa74-4c26-92e8-a8f0c1079382
-
-## Try CRM 🍻
-* Try CRM at [Huggingface Demo](https://huggingface.co./spaces/Zhengyi/CRM).
-* Try CRM at [Replicate Demo](https://replicate.com/camenduru/crm). Thanks [@camenduru](https://github.com/camenduru)!
-
-## Install
-
-### Step 1 - Base
-
-Install the packages one by one; we use **python 3.9**.
-
-```bash
-pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
-pip install torch-scatter==2.1.1 -f https://data.pyg.org/whl/torch-1.13.1+cu117.html
-pip install kaolin==0.14.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.13.1_cu117.html
-pip install -r requirements.txt
-```
-
-Besides, you need to install xformers manually according to the official [doc](https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers) (**not needed for conda**), e.g.
-
-```bash
-pip install ninja
-pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
-```
-
-### Step 2 - Nvdiffrast
-
-Install nvdiffrast according to the official [doc](https://nvlabs.github.io/nvdiffrast/#installation), e.g.
-
-```bash
-pip install git+https://github.com/NVlabs/nvdiffrast
-```
-
-
-
-## Inference
-
-We suggest using gradio for visualized inference.
-
-```
-gradio app.py
-```
-
-![image](https://github.com/thu-ml/CRM/assets/40787266/4354d22a-a641-4531-8408-c761ead8b1a2)
-
-For inference in command lines, simply run
-```bash
-CUDA_VISIBLE_DEVICES="0" python run.py --inputdir "examples/kunkun.webp"
-```
-It will output the preprocessed image, the generated 6-view images and CCMs, and a 3D model in obj format.
-
-**Tips:** (1) If the result is unsatisfactory, please check whether the input image is correctly pre-processed into a grey background. Otherwise the results will be unpredictable.
-(2) Unlike the [Huggingface Demo](https://huggingface.co./spaces/Zhengyi/CRM), this official implementation uses UV texture instead of vertex color. It produces better texture than the online demo but takes longer to generate owing to the UV texturing.
-
-## Todo List
-- [x] Release inference code.
-- [x] Release pretrained models.
-- [ ] Optimize inference code to fit on low-memory GPUs.
-- [ ] Upload training code.
-
-## Acknowledgement
-- [ImageDream](https://github.com/bytedance/ImageDream)
-- [nvdiffrast](https://github.com/NVlabs/nvdiffrast)
-- [kiuikit](https://github.com/ashawkey/kiuikit)
-- [GET3D](https://github.com/nv-tlabs/GET3D)
-
-## Citation
-
-```
-@article{wang2024crm,
- title={CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model},
- author={Zhengyi Wang and Yikai Wang and Yifei Chen and Chendong Xiang and Shuo Chen and Dajiang Yu and Chongxuan Li and Hang Su and Jun Zhu},
- journal={arXiv preprint arXiv:2403.05034},
- year={2024}
-}
-```
diff --git a/apps/third_party/CRM/app.py b/apps/third_party/CRM/app.py
deleted file mode 100644
index dae2f6baf19cf3bb2817f166c81679ae4eec0685..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/app.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# Not ready to use yet
-import argparse
-import numpy as np
-import gradio as gr
-from omegaconf import OmegaConf
-import torch
-from PIL import Image
-import PIL
-from pipelines import TwoStagePipeline
-from huggingface_hub import hf_hub_download
-import os
-import rembg
-from typing import Any
-import json
-import os
-import json
-import argparse
-
-from model import CRM
-from inference import generate3d
-
-pipeline = None
-rembg_session = rembg.new_session()
-
-
-def expand_to_square(image, bg_color=(0, 0, 0, 0)):
- # expand image to 1:1
- width, height = image.size
- if width == height:
- return image
- new_size = (max(width, height), max(width, height))
- new_image = Image.new("RGBA", new_size, bg_color)
- paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
- new_image.paste(image, paste_position)
- return new_image
-
-def check_input_image(input_image):
- if input_image is None:
- raise gr.Error("No image uploaded!")
-
-
-def remove_background(
- image: PIL.Image.Image,
- rembg_session = None,
- force: bool = False,
- **rembg_kwargs,
-) -> PIL.Image.Image:
- do_remove = True
- if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
-        # explain why we currently do not remove the background
-        print("alpha channel not empty, skip background removal, using alpha channel as mask")
- background = Image.new("RGBA", image.size, (0, 0, 0, 0))
- image = Image.alpha_composite(background, image)
- do_remove = False
- do_remove = do_remove or force
- if do_remove:
- image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
- return image
-
-def do_resize_content(original_image: Image, scale_rate):
-    # resize image content while retaining the original image size
- if scale_rate != 1:
- # Calculate the new size after rescaling
- new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
- # Resize the image while maintaining the aspect ratio
- resized_image = original_image.resize(new_size)
- # Create a new image with the original size and black background
- padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
- paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
- padded_image.paste(resized_image, paste_position)
- return padded_image
- else:
- return original_image
-
-def add_background(image, bg_color=(255, 255, 255)):
- # given an RGBA image, alpha channel is used as mask to add background color
- background = Image.new("RGBA", image.size, bg_color)
- return Image.alpha_composite(background, image)
-
-
-def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
- """
-    input image is a PIL image in RGBA; returns an RGB image
- """
- print(background_choice)
- if background_choice == "Alpha as mask":
- background = Image.new("RGBA", image.size, (0, 0, 0, 0))
- image = Image.alpha_composite(background, image)
- else:
-        image = remove_background(image, rembg_session, force=True)
- image = do_resize_content(image, foreground_ratio)
- image = expand_to_square(image)
- image = add_background(image, backgroud_color)
- return image.convert("RGB")
-
-
-def gen_image(input_image, seed, scale, step):
- global pipeline, model, args
- pipeline.set_seed(seed)
- rt_dict = pipeline(input_image, scale=scale, step=step)
- stage1_images = rt_dict["stage1_images"]
- stage2_images = rt_dict["stage2_images"]
- np_imgs = np.concatenate(stage1_images, 1)
- np_xyzs = np.concatenate(stage2_images, 1)
-
- glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device)
- return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path, obj_path
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
- "--stage1_config",
- type=str,
- default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
- help="config for stage1",
-)
-parser.add_argument(
- "--stage2_config",
- type=str,
- default="configs/stage2-v2-snr.yaml",
- help="config for stage2",
-)
-
-parser.add_argument("--device", type=str, default="cuda")
-args = parser.parse_args()
-
-crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
-specs = json.load(open("configs/specs_objaverse_total.json"))
-model = CRM(specs).to(args.device)
-model.load_state_dict(torch.load(crm_path, map_location = args.device), strict=False)
-
-stage1_config = OmegaConf.load(args.stage1_config).config
-stage2_config = OmegaConf.load(args.stage2_config).config
-stage2_sampler_config = stage2_config.sampler
-stage1_sampler_config = stage1_config.sampler
-
-stage1_model_config = stage1_config.models
-stage2_model_config = stage2_config.models
-
-xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
-pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
-stage1_model_config.resume = pixel_path
-stage2_model_config.resume = xyz_path
-
-pipeline = TwoStagePipeline(
- stage1_model_config,
- stage2_model_config,
- stage1_sampler_config,
- stage2_sampler_config,
- device=args.device,
- dtype=torch.float16
-)
-
-with gr.Blocks() as demo:
- gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
- with gr.Row():
- with gr.Column():
- with gr.Row():
- image_input = gr.Image(
- label="Image input",
- image_mode="RGBA",
- sources="upload",
- type="pil",
- )
- processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB")
- with gr.Row():
- with gr.Column():
- with gr.Row():
- background_choice = gr.Radio([
- "Alpha as mask",
- "Auto Remove background"
- ], value="Auto Remove background",
-                        label="background choice")
- # do_remove_background = gr.Checkbox(label=, value=True)
- # force_remove = gr.Checkbox(label=, value=False)
- back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
- foreground_ratio = gr.Slider(
- label="Foreground Ratio",
- minimum=0.5,
- maximum=1.0,
- value=1.0,
- step=0.05,
- )
-
- with gr.Column():
- seed = gr.Number(value=1234, label="seed", precision=0)
- guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale")
- step = gr.Number(value=50, minimum=30, maximum=100, label="sample steps", precision=0)
- text_button = gr.Button("Generate 3D shape")
- gr.Examples(
- examples=[os.path.join("examples", i) for i in os.listdir("examples")],
- inputs=[image_input],
- )
- with gr.Column():
- image_output = gr.Image(interactive=False, label="Output RGB image")
- xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
-
- output_model = gr.Model3D(
- label="Output GLB",
- interactive=False,
- )
-            gr.Markdown("Note: The GLB model shown here has darker lighting and enlarged UV seams. Download it for correct results.")
- output_obj = gr.File(interactive=False, label="Output OBJ")
-
- inputs = [
- processed_image,
- seed,
- guidance_scale,
- step,
- ]
- outputs = [
- image_output,
- xyz_ouput,
- output_model,
- output_obj,
- ]
-
-
- text_button.click(fn=check_input_image, inputs=[image_input]).success(
- fn=preprocess_image,
- inputs=[image_input, background_choice, foreground_ratio, back_groud_color],
- outputs=[processed_image],
- ).success(
- fn=gen_image,
- inputs=inputs,
- outputs=outputs,
- )
- demo.queue().launch()
diff --git a/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml b/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml
deleted file mode 100644
index 760f41f2728a94114f674e2160a75a65a8a3a656..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-config:
-# others
- seed: 1234
- num_frames: 7
- mode: pixel
- offset_noise: true
-# model related
- models:
- config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
- resume: models/pixel.pth
-# sampler related
- sampler:
- target: libs.sample.ImageDreamDiffusion
- params:
- mode: pixel
- num_frames: 7
- camera_views: [1, 2, 3, 4, 5, 0, 0]
- ref_position: 6
- random_background: false
- offset_noise: true
- resize_rate: 1.0
\ No newline at end of file
diff --git a/apps/third_party/CRM/configs/specs_objaverse_total.json b/apps/third_party/CRM/configs/specs_objaverse_total.json
deleted file mode 100644
index c99ebee563a7d44859338382b197ef55963e87d0..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/configs/specs_objaverse_total.json
+++ /dev/null
@@ -1,57 +0,0 @@
-{
- "Input": {
- "img_num": 16,
- "class": "all",
- "camera_angle_num": 8,
- "tet_grid_size": 80,
- "validate_num": 16,
- "scale": 0.95,
- "radius": 3,
- "resolution": [256, 256]
- },
-
- "Pretrain": {
- "mode": null,
- "sdf_threshold": 0.1,
- "sdf_scale": 10,
- "batch_infer": false,
- "lr": 1e-4,
- "radius": 0.5
- },
-
- "Train": {
- "mode": "rnd",
- "num_epochs": 500,
- "grad_acc": 1,
- "warm_up": 0,
- "decay": 0.000,
- "learning_rate": {
- "init": 1e-4,
- "sdf_decay": 1,
- "rgb_decay": 1
- },
- "batch_size": 4,
- "eva_iter": 80,
- "eva_all_epoch": 10,
- "tex_sup_mode": "blender",
- "exp_uv_mesh": false,
- "doub": false,
- "random_bg": false,
- "shift": 0,
- "aug_shift": 0,
- "geo_type": "flex"
- },
-
- "ArchSpecs": {
- "unet_type": "diffusers",
- "use_3D_aware": false,
- "fea_concat": false,
- "mlp_bias": true
- },
-
- "DecoderSpecs": {
- "c_dim": 32,
- "plane_resolution": 256
- }
-}
-
diff --git a/apps/third_party/CRM/configs/stage2-v2-snr.yaml b/apps/third_party/CRM/configs/stage2-v2-snr.yaml
deleted file mode 100644
index 8e76d1a2a8ff71ba1318c9ff6ff6c59a7a9e606e..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/configs/stage2-v2-snr.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-config:
-# others
- seed: 1234
- num_frames: 6
- mode: pixel
- offset_noise: true
- gd_type: xyz
-# model related
- models:
- config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
- resume: models/xyz.pth
-
-# eval related
- sampler:
- target: libs.sample.ImageDreamDiffusionStage2
- params:
- mode: pixel
- num_frames: 6
- camera_views: [1, 2, 3, 4, 5, 0]
- ref_position: null
- random_background: false
- offset_noise: true
- resize_rate: 1.0
-
-
diff --git a/apps/third_party/CRM/imagedream/.DS_Store b/apps/third_party/CRM/imagedream/.DS_Store
deleted file mode 100644
index 5e042f9aa5cf18c3b51e766fee10546755d5ed4d..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/.DS_Store and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/__init__.py b/apps/third_party/CRM/imagedream/__init__.py
deleted file mode 100644
index 326f18c2d65a018b1c214c71f1c44428a8a8089b..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .model_zoo import build_model
diff --git a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 5514f1066fcdb863f45a135e26483e3b31c2a44d..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index f7c150c7cb625327e1ccb59742e7c4842561eb1d..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc
deleted file mode 100644
index aaaf53fc96fd79461da559cd87213eac03d63e1a..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc
deleted file mode 100644
index 599048b52c049602f55ce7f6b569e278ae8e8dd6..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc
deleted file mode 100644
index 7eaa9fbbbfe649e49709f705e513ef0f40d7690d..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc
deleted file mode 100644
index 6bd45cca067b99143299e1ba9b78a2a23846e390..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/camera_utils.py b/apps/third_party/CRM/imagedream/camera_utils.py
deleted file mode 100644
index 6fb352745d737d96b2652f80c23778058e81f6e6..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/camera_utils.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import numpy as np
-import torch
-
-
-def create_camera_to_world_matrix(elevation, azimuth):
- elevation = np.radians(elevation)
- azimuth = np.radians(azimuth)
- # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
- x = np.cos(elevation) * np.sin(azimuth)
- y = np.sin(elevation)
- z = np.cos(elevation) * np.cos(azimuth)
-
- # Calculate camera position, target, and up vectors
- camera_pos = np.array([x, y, z])
- target = np.array([0, 0, 0])
- up = np.array([0, 1, 0])
-
- # Construct view matrix
- forward = target - camera_pos
- forward /= np.linalg.norm(forward)
- right = np.cross(forward, up)
- right /= np.linalg.norm(right)
- new_up = np.cross(right, forward)
- new_up /= np.linalg.norm(new_up)
- cam2world = np.eye(4)
- cam2world[:3, :3] = np.array([right, new_up, -forward]).T
- cam2world[:3, 3] = camera_pos
- return cam2world
-
-
-def convert_opengl_to_blender(camera_matrix):
- if isinstance(camera_matrix, np.ndarray):
- # Construct transformation matrix to convert from OpenGL space to Blender space
- flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
- camera_matrix_blender = np.dot(flip_yz, camera_matrix)
- else:
- # Construct transformation matrix to convert from OpenGL space to Blender space
- flip_yz = torch.tensor(
- [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
- )
- if camera_matrix.ndim == 3:
- flip_yz = flip_yz.unsqueeze(0)
- camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
- return camera_matrix_blender
-
-
-def normalize_camera(camera_matrix):
- """normalize the camera location onto a unit-sphere"""
- if isinstance(camera_matrix, np.ndarray):
- camera_matrix = camera_matrix.reshape(-1, 4, 4)
- translation = camera_matrix[:, :3, 3]
- translation = translation / (
- np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8
- )
- camera_matrix[:, :3, 3] = translation
- else:
- camera_matrix = camera_matrix.reshape(-1, 4, 4)
- translation = camera_matrix[:, :3, 3]
- translation = translation / (
- torch.norm(translation, dim=1, keepdim=True) + 1e-8
- )
- camera_matrix[:, :3, 3] = translation
- return camera_matrix.reshape(-1, 16)
-
-
-def get_camera(
- num_frames,
- elevation=15,
- azimuth_start=0,
- azimuth_span=360,
- blender_coord=True,
- extra_view=False,
-):
- angle_gap = azimuth_span / num_frames
- cameras = []
- for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
- camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
- if blender_coord:
- camera_matrix = convert_opengl_to_blender(camera_matrix)
- cameras.append(camera_matrix.flatten())
-
- if extra_view:
- dim = len(cameras[0])
- cameras.append(np.zeros(dim))
- return torch.tensor(np.stack(cameras, 0)).float()
-
-
-def get_camera_for_index(data_index):
- """
-    Under our current data format, with view 000 facing directly toward us:
-    000 is the front view, ev: 0, azimuth: 0
-    001 is the left view, ev: 0, azimuth: -90
-    002 is the bottom view, ev: -90, azimuth: 0
-    003 is the back view, ev: 0, azimuth: 180
-    004 is the right view, ev: 0, azimuth: 90
-    005 is the top view, ev: 90, azimuth: 0
- """
- params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)]
- return get_camera(1, *params[data_index])
\ No newline at end of file
diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml
deleted file mode 100644
index b4ecc2a694dbbde82d45bc3b16d1ebbb27ac552a..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml
+++ /dev/null
@@ -1,61 +0,0 @@
-model:
- target: imagedream.ldm.interface.LatentDiffusionInterface
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- timesteps: 1000
- scale_factor: 0.18215
- parameterization: "eps"
-
- unet_config:
- target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
- params:
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- use_checkpoint: False
- legacy: False
- camera_dim: 16
- with_ip: True
- ip_dim: 16 # ip token length
- ip_mode: "local_resample"
-
- vae_config:
- target: imagedream.ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- clip_config:
- target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
- ip_mode: "local_resample"
diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml
deleted file mode 100644
index ee80712395fa6323bceb437a40b4829bb497adf7..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml
+++ /dev/null
@@ -1,61 +0,0 @@
-model:
- target: imagedream.ldm.interface.LatentDiffusionInterface
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- timesteps: 1000
- scale_factor: 0.18215
- parameterization: "eps"
-
- unet_config:
- target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
- params:
- image_size: 32 # unused
- in_channels: 8
- out_channels: 8
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- use_checkpoint: False
- legacy: False
- camera_dim: 16
- with_ip: True
- ip_dim: 16 # ip token length
- ip_mode: "local_resample"
-
- vae_config:
- target: imagedream.ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- clip_config:
- target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
- ip_mode: "local_resample"
diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml
deleted file mode 100644
index ce64b053fc76ded2ed85af7c5d398e2981e51c3d..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml
+++ /dev/null
@@ -1,61 +0,0 @@
-model:
- target: imagedream.ldm.interface.LatentDiffusionInterface
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- timesteps: 1000
- scale_factor: 0.18215
- parameterization: "eps"
-
- unet_config:
- target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
- params:
- image_size: 32 # unused
- in_channels: 8
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- use_checkpoint: False
- legacy: False
- camera_dim: 16
- with_ip: True
- ip_dim: 16 # ip token length
- ip_mode: "local_resample"
-
- vae_config:
- target: imagedream.ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- clip_config:
- target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
- ip_mode: "local_resample"
diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
deleted file mode 100644
index bd9c835f06707aeffcd5cd8fd4cd7ec262646905..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
+++ /dev/null
@@ -1,62 +0,0 @@
-model:
- target: imagedream.ldm.interface.LatentDiffusionInterface
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- timesteps: 1000
- scale_factor: 0.18215
- parameterization: "eps"
- zero_snr: true
-
- unet_config:
- target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
- params:
- image_size: 32 # unused
- in_channels: 8
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- use_checkpoint: False
- legacy: False
- camera_dim: 16
- with_ip: True
- ip_dim: 16 # ip token length
- ip_mode: "local_resample"
-
- vae_config:
- target: imagedream.ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- clip_config:
- target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
- ip_mode: "local_resample"
diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml
deleted file mode 100644
index b4fbe27c84d5c2b05737894dc8970f45f08ba606..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml
+++ /dev/null
@@ -1,62 +0,0 @@
-model:
- target: imagedream.ldm.interface.LatentDiffusionInterface
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- timesteps: 1000
- scale_factor: 0.18215
- parameterization: "eps"
-
- unet_config:
- target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
- params:
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- use_checkpoint: False
- legacy: False
- camera_dim: 16
- with_ip: True
- ip_dim: 16 # ip token length
- ip_mode: "local_resample"
- ip_weight: 1.0 # adjust for similarity to image
-
- vae_config:
- target: imagedream.ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- clip_config:
- target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
- ip_mode: "local_resample"
diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
deleted file mode 100644
index 5824cd7f25a9d1ef2e397341a39f7c724eb1cf76..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
+++ /dev/null
@@ -1,62 +0,0 @@
-model:
- target: imagedream.ldm.interface.LatentDiffusionInterface
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- timesteps: 1000
- scale_factor: 0.18215
- parameterization: "eps"
- zero_snr: true
-
- unet_config:
- target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
- params:
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- use_checkpoint: False
- legacy: False
- camera_dim: 16
- with_ip: True
- ip_dim: 16 # ip token length
- ip_mode: "local_resample"
-
- vae_config:
- target: imagedream.ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- clip_config:
- target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
- ip_mode: "local_resample"
diff --git a/apps/third_party/CRM/imagedream/ldm/__init__.py b/apps/third_party/CRM/imagedream/ldm/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index d029d6984122456fadb1fcb2eb074aa9471b82d4..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 02bf8990d7b4f5cdaf387895d86be5e17bf8ed39..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc
deleted file mode 100644
index 01d44982e9c3c83b6c7479b28fe0604fd8f79fbd..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc
deleted file mode 100644
index d9f7089053acea4170588b4fc2a48adf647beaf0..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc
deleted file mode 100644
index 0890976472bd4db53fbbf6883313845de8002651..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc
deleted file mode 100644
index e30a7fbad8ea87fce676138a6a8e481191d09dd7..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/interface.py b/apps/third_party/CRM/imagedream/ldm/interface.py
deleted file mode 100644
index 3bbeed7921ea30efb81573c3efed42d8375cb21a..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/interface.py
+++ /dev/null
@@ -1,206 +0,0 @@
-from typing import List
-from functools import partial
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from .modules.diffusionmodules.util import (
- make_beta_schedule,
- extract_into_tensor,
- enforce_zero_terminal_snr,
- noise_like,
-)
-from .util import exists, default, instantiate_from_config
-from .modules.distributions.distributions import DiagonalGaussianDistribution
-
-
-class DiffusionWrapper(nn.Module):
- def __init__(self, diffusion_model):
- super().__init__()
- self.diffusion_model = diffusion_model
-
- def forward(self, *args, **kwargs):
- return self.diffusion_model(*args, **kwargs)
-
-
-class LatentDiffusionInterface(nn.Module):
- """a simple interface class for LDM inference"""
-
- def __init__(
- self,
- unet_config,
- clip_config,
- vae_config,
- parameterization="eps",
- scale_factor=0.18215,
- beta_schedule="linear",
- timesteps=1000,
- linear_start=0.00085,
- linear_end=0.0120,
- cosine_s=8e-3,
- given_betas=None,
- zero_snr=False,
- *args,
- **kwargs,
- ):
- super().__init__()
-
- unet = instantiate_from_config(unet_config)
- self.model = DiffusionWrapper(unet)
- self.clip_model = instantiate_from_config(clip_config)
- self.vae_model = instantiate_from_config(vae_config)
-
- self.parameterization = parameterization
- self.scale_factor = scale_factor
- self.register_schedule(
- given_betas=given_betas,
- beta_schedule=beta_schedule,
- timesteps=timesteps,
- linear_start=linear_start,
- linear_end=linear_end,
- cosine_s=cosine_s,
- zero_snr=zero_snr
- )
-
- def register_schedule(
- self,
- given_betas=None,
- beta_schedule="linear",
- timesteps=1000,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- zero_snr=False
- ):
- if exists(given_betas):
- betas = given_betas
- else:
- betas = make_beta_schedule(
- beta_schedule,
- timesteps,
- linear_start=linear_start,
- linear_end=linear_end,
- cosine_s=cosine_s,
- )
- if zero_snr:
- print("--- using zero snr---")
- betas = enforce_zero_terminal_snr(betas).numpy()
- alphas = 1.0 - betas
- alphas_cumprod = np.cumprod(alphas, axis=0)
- alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
-
- (timesteps,) = betas.shape
- self.num_timesteps = int(timesteps)
- self.linear_start = linear_start
- self.linear_end = linear_end
- assert (
- alphas_cumprod.shape[0] == self.num_timesteps
- ), "alphas have to be defined for each timestep"
-
- to_torch = partial(torch.tensor, dtype=torch.float32)
-
- self.register_buffer("betas", to_torch(betas))
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
- self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
- self.register_buffer(
- "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
- )
- self.register_buffer(
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
- )
-        eps = 1e-8  # add a small epsilon to avoid divide-by-zero errors
- self.register_buffer(
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps)))
- )
- self.register_buffer(
- "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1))
- )
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- self.v_posterior = 0
- posterior_variance = (1 - self.v_posterior) * betas * (
- 1.0 - alphas_cumprod_prev
- ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.register_buffer("posterior_variance", to_torch(posterior_variance))
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.register_buffer(
- "posterior_log_variance_clipped",
- to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
- )
- self.register_buffer(
- "posterior_mean_coef1",
- to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
- )
- self.register_buffer(
- "posterior_mean_coef2",
- to_torch(
- (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
- ),
- )
-
- def q_sample(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
- + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
- * noise
- )
-
- def get_v(self, x, noise, t):
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
- - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
- )
-
- def predict_start_from_noise(self, x_t, t, noise):
- return (
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
- * noise
- )
-
- def predict_start_from_z_and_v(self, x_t, t, v):
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
- - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
- )
-
- def predict_eps_from_z_and_v(self, x_t, t, v):
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
- + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
- * x_t
- )
-
- def apply_model(self, x_noisy, t, cond, **kwargs):
- assert isinstance(cond, dict), "cond has to be a dictionary"
- return self.model(x_noisy, t, **cond, **kwargs)
-
- def get_learned_conditioning(self, prompts: List[str]):
- return self.clip_model(prompts)
-
- def get_learned_image_conditioning(self, images):
- return self.clip_model.forward_image(images)
-
- def get_first_stage_encoding(self, encoder_posterior):
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
- z = encoder_posterior.sample()
- elif isinstance(encoder_posterior, torch.Tensor):
- z = encoder_posterior
- else:
- raise NotImplementedError(
- f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
- )
- return self.scale_factor * z
-
- def encode_first_stage(self, x):
- return self.vae_model.encode(x)
-
- def decode_first_stage(self, z):
- z = 1.0 / self.scale_factor * z
- return self.vae_model.decode(z)
diff --git a/apps/third_party/CRM/imagedream/ldm/models/__init__.py b/apps/third_party/CRM/imagedream/ldm/models/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 672ea939d54368780d02dd86a921336018342e63..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index e64c9b27cd3f052820627d22906cadc4b3e9dd43..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc
deleted file mode 100644
index 6ab9a034b2a331fc87a7e50e1a44af00ff39efac..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc
deleted file mode 100644
index 5e3a48c1e459902bdb6f0621a7ede086b9c8cd08..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py b/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py
deleted file mode 100644
index 92f83096ddaf2146772f6b49b23d8f99e787fbb4..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import torch
-import torch.nn.functional as F
-from contextlib import contextmanager
-
-from ..modules.diffusionmodules.model import Encoder, Decoder
-from ..modules.distributions.distributions import DiagonalGaussianDistribution
-
-from ..util import instantiate_from_config
-from ..modules.ema import LitEma
-
-
-class AutoencoderKL(torch.nn.Module):
- def __init__(
- self,
- ddconfig,
- lossconfig,
- embed_dim,
- ckpt_path=None,
- ignore_keys=[],
- image_key="image",
- colorize_nlabels=None,
- monitor=None,
- ema_decay=None,
- learn_logvar=False,
- ):
- super().__init__()
- self.learn_logvar = learn_logvar
- self.image_key = image_key
- self.encoder = Encoder(**ddconfig)
- self.decoder = Decoder(**ddconfig)
- self.loss = instantiate_from_config(lossconfig)
- assert ddconfig["double_z"]
- self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
- self.embed_dim = embed_dim
- if colorize_nlabels is not None:
- assert type(colorize_nlabels) == int
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
- if monitor is not None:
- self.monitor = monitor
-
- self.use_ema = ema_decay is not None
- if self.use_ema:
- self.ema_decay = ema_decay
- assert 0.0 < ema_decay < 1.0
- self.model_ema = LitEma(self, decay=ema_decay)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
-
- def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- self.load_state_dict(sd, strict=False)
- print(f"Restored from {path}")
-
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.parameters())
- self.model_ema.copy_to(self)
- if context is not None:
- print(f"{context}: Switched to EMA weights")
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.parameters())
- if context is not None:
- print(f"{context}: Restored training weights")
-
- def on_train_batch_end(self, *args, **kwargs):
- if self.use_ema:
- self.model_ema(self)
-
- def encode(self, x):
- h = self.encoder(x)
- moments = self.quant_conv(h)
- posterior = DiagonalGaussianDistribution(moments)
- return posterior
-
- def decode(self, z):
- z = self.post_quant_conv(z)
- dec = self.decoder(z)
- return dec
-
- def forward(self, input, sample_posterior=True):
- posterior = self.encode(input)
- if sample_posterior:
- z = posterior.sample()
- else:
- z = posterior.mode()
- dec = self.decode(z)
- return dec, posterior
-
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
- return x
-
- def training_step(self, batch, batch_idx, optimizer_idx):
- inputs = self.get_input(batch, self.image_key)
- reconstructions, posterior = self(inputs)
-
- if optimizer_idx == 0:
- # train encoder+decoder+logvar
- aeloss, log_dict_ae = self.loss(
- inputs,
- reconstructions,
- posterior,
- optimizer_idx,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="train",
- )
- self.log(
- "aeloss",
- aeloss,
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
- )
- self.log_dict(
- log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
- return aeloss
-
- if optimizer_idx == 1:
- # train the discriminator
- discloss, log_dict_disc = self.loss(
- inputs,
- reconstructions,
- posterior,
- optimizer_idx,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="train",
- )
-
- self.log(
- "discloss",
- discloss,
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
- )
- self.log_dict(
- log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
- return discloss
-
- def validation_step(self, batch, batch_idx):
- log_dict = self._validation_step(batch, batch_idx)
- with self.ema_scope():
- log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
- return log_dict
-
- def _validation_step(self, batch, batch_idx, postfix=""):
- inputs = self.get_input(batch, self.image_key)
- reconstructions, posterior = self(inputs)
- aeloss, log_dict_ae = self.loss(
- inputs,
- reconstructions,
- posterior,
- 0,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="val" + postfix,
- )
-
- discloss, log_dict_disc = self.loss(
- inputs,
- reconstructions,
- posterior,
- 1,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="val" + postfix,
- )
-
- self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
- self.log_dict(log_dict_ae)
- self.log_dict(log_dict_disc)
- return self.log_dict
-
- def configure_optimizers(self):
- lr = self.learning_rate
- ae_params_list = (
- list(self.encoder.parameters())
- + list(self.decoder.parameters())
- + list(self.quant_conv.parameters())
- + list(self.post_quant_conv.parameters())
- )
- if self.learn_logvar:
- print(f"{self.__class__.__name__}: Learning logvar")
- ae_params_list.append(self.loss.logvar)
- opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
- opt_disc = torch.optim.Adam(
- self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
- )
- return [opt_ae, opt_disc], []
-
- def get_last_layer(self):
- return self.decoder.conv_out.weight
-
- @torch.no_grad()
- def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
- log = dict()
- x = self.get_input(batch, self.image_key)
- x = x.to(self.device)
- if not only_inputs:
- xrec, posterior = self(x)
- if x.shape[1] > 3:
- # colorize with random projection
- assert xrec.shape[1] > 3
- x = self.to_rgb(x)
- xrec = self.to_rgb(xrec)
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
- log["reconstructions"] = xrec
- if log_ema or self.use_ema:
- with self.ema_scope():
- xrec_ema, posterior_ema = self(x)
- if x.shape[1] > 3:
- # colorize with random projection
- assert xrec_ema.shape[1] > 3
- xrec_ema = self.to_rgb(xrec_ema)
- log["samples_ema"] = self.decode(
- torch.randn_like(posterior_ema.sample())
- )
- log["reconstructions_ema"] = xrec_ema
- log["inputs"] = x
- return log
-
- def to_rgb(self, x):
- assert self.image_key == "segmentation"
- if not hasattr(self, "colorize"):
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
- x = F.conv2d(x, weight=self.colorize)
- x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
- return x
-
-
-class IdentityFirstStage(torch.nn.Module):
- def __init__(self, *args, vq_interface=False, **kwargs):
- super().__init__()
- self.vq_interface = vq_interface
-
- def encode(self, x, *args, **kwargs):
- return x
-
- def decode(self, x, *args, **kwargs):
- return x
-
- def quantize(self, x, *args, **kwargs):
- if self.vq_interface:
- return x, None, [None, None, None]
- return x
-
- def forward(self, x, *args, **kwargs):
- return x
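
For orientation, a minimal usage sketch of the autoencoder interface deleted above (encode / decode / forward). It assumes a hypothetical, already-constructed instance named model of this class and an image batch normalized to [-1, 1]; this is an illustrative sketch, not code from the repository:

import torch

x = torch.randn(4, 3, 256, 256)          # placeholder batch of images in [-1, 1], shape (B, 3, H, W)

posterior = model.encode(x)              # DiagonalGaussianDistribution over the latent space
z = posterior.sample()                   # stochastic latent; posterior.mode() gives the mean instead
rec = model.decode(z)                    # reconstruction with the same spatial size as x

# equivalent round trip through forward():
rec, posterior = model(x, sample_posterior=True)
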
diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index a5c8f08217241b9ddc7d371105a50328a41ca897..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 9f083019b05ad6e286010c0cd0a21e6d9f9b2b9d..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc
deleted file mode 100644
index 62e7bce7e47a73aa87c94c0cfdd772df724aff7a..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
deleted file mode 100644
index 0512edd43d2c8d4c576a7076c33332b30b0e0d6e..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py b/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py
deleted file mode 100644
index 4c10321c39078e985f18a0ca7b388086a4aa4e2f..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py
+++ /dev/null
@@ -1,430 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-
-from ...modules.diffusionmodules.util import (
- make_ddim_sampling_parameters,
- make_ddim_timesteps,
- noise_like,
- extract_into_tensor,
-)
-
-
-class DDIMSampler(object):
- def __init__(self, model, schedule="linear", **kwargs):
- super().__init__()
- self.model = model
- self.ddpm_num_timesteps = model.num_timesteps
- self.schedule = schedule
-
- def register_buffer(self, name, attr):
- if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
- setattr(self, name, attr)
-
- def make_schedule(
- self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
- ):
- self.ddim_timesteps = make_ddim_timesteps(
- ddim_discr_method=ddim_discretize,
- num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,
- verbose=verbose,
- )
- alphas_cumprod = self.model.alphas_cumprod
- assert (
- alphas_cumprod.shape[0] == self.ddpm_num_timesteps
- ), "alphas have to be defined for each timestep"
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
- self.register_buffer("betas", to_torch(self.model.betas))
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
- self.register_buffer(
- "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
- )
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer(
- "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
- )
- self.register_buffer(
- "sqrt_one_minus_alphas_cumprod",
- to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
- )
- self.register_buffer(
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
- )
- self.register_buffer(
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
- )
- self.register_buffer(
- "sqrt_recipm1_alphas_cumprod",
- to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
- )
-
- # ddim sampling parameters
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
- alphacums=alphas_cumprod.cpu(),
- ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,
- verbose=verbose,
- )
- self.register_buffer("ddim_sigmas", ddim_sigmas)
- self.register_buffer("ddim_alphas", ddim_alphas)
- self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
- self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
- (1 - self.alphas_cumprod_prev)
- / (1 - self.alphas_cumprod)
- * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
- )
- self.register_buffer(
- "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
- )
-
- @torch.no_grad()
- def sample(
- self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.0,
- mask=None,
- x0=None,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs,
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
- if cbs != batch_size:
- print(
- f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
- )
- else:
- if conditioning.shape[0] != batch_size:
- print(
- f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
- )
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
-
- samples, intermediates = self.ddim_sampling(
- conditioning,
- size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask,
- x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- **kwargs,
- )
- return samples, intermediates
-
- @torch.no_grad()
- def ddim_sampling(
- self,
- cond,
- shape,
- x_T=None,
- ddim_use_original_steps=False,
- callback=None,
- timesteps=None,
- quantize_denoised=False,
- mask=None,
- x0=None,
- img_callback=None,
- log_every_t=100,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- **kwargs,
- ):
- """
- At inference time the parameters take the following values:
- cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
- shape: (5, 4, 32, 32)
- x_T: None
- ddim_use_original_steps: False
- timesteps: None
- callback: None
- quantize_denoised: False
- mask: None
- img_callback: None
- log_every_t: 100
- temperature: 1.0
- noise_dropout: 0.0
- score_corrector: None
- corrector_kwargs: None
- unconditional_guidance_scale: 5
- unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
- kwargs: {}
- """
- device = self.model.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94
- else:
- img = x_T
-
- if timesteps is None:  # equivalent to setting the timesteps in the HF diffusers scheduler
- timesteps = (
- self.ddpm_num_timesteps
- if ddim_use_original_steps
- else self.ddim_timesteps
- )
- elif timesteps is not None and not ddim_use_original_steps:
- subset_end = (
- int(
- min(timesteps / self.ddim_timesteps.shape[0], 1)
- * self.ddim_timesteps.shape[0]
- )
- - 1
- )
- timesteps = self.ddim_timesteps[:subset_end]
-
- intermediates = {"x_inter": [img], "pred_x0": [img]}
- time_range = ( # reversed timesteps
- reversed(range(0, timesteps))
- if ddim_use_original_steps
- else np.flip(timesteps)
- )
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
- iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((b,), step, device=device, dtype=torch.long)
-
- if mask is not None:
- assert x0 is not None
- img_orig = self.model.q_sample(
- x0, ts
- ) # TODO: deterministic forward pass?
- img = img_orig * mask + (1.0 - mask) * img
-
- outs = self.p_sample_ddim(
- img,
- cond,
- ts,
- index=index,
- use_original_steps=ddim_use_original_steps,
- quantize_denoised=quantize_denoised,
- temperature=temperature,
- noise_dropout=noise_dropout,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- **kwargs,
- )
- img, pred_x0 = outs
- if callback:
- callback(i)
- if img_callback:
- img_callback(pred_x0, i)
-
- if index % log_every_t == 0 or index == total_steps - 1:
- intermediates["x_inter"].append(img)
- intermediates["pred_x0"].append(pred_x0)
-
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_ddim(
- self,
- x,
- c,
- t,
- index,
- repeat_noise=False,
- use_original_steps=False,
- quantize_denoised=False,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- dynamic_threshold=None,
- **kwargs,
- ):
- b, *_, device = *x.shape, x.device
-
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
- model_output = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- if isinstance(c, dict):
- assert isinstance(unconditional_conditioning, dict)
- c_in = dict()
- for k in c:
- if isinstance(c[k], list):
- c_in[k] = [
- torch.cat([unconditional_conditioning[k][i], c[k][i]])
- for i in range(len(c[k]))
- ]
- elif isinstance(c[k], torch.Tensor):
- c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
- else:
- assert c[k] == unconditional_conditioning[k]
- c_in[k] = c[k]
- elif isinstance(c, list):
- c_in = list()
- assert isinstance(unconditional_conditioning, list)
- for i in range(len(c)):
- c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
- else:
- c_in = torch.cat([unconditional_conditioning, c])
- model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- model_output = model_uncond + unconditional_guidance_scale * (
- model_t - model_uncond
- )
-
-
- if self.model.parameterization == "v":
- print("using v!")
- e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
- else:
- e_t = model_output
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps", "not implemented"
- e_t = score_corrector.modify_score(
- self.model, e_t, x, t, c, **corrector_kwargs
- )
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = (
- self.model.alphas_cumprod_prev
- if use_original_steps
- else self.ddim_alphas_prev
- )
- sqrt_one_minus_alphas = (
- self.model.sqrt_one_minus_alphas_cumprod
- if use_original_steps
- else self.ddim_sqrt_one_minus_alphas
- )
- sigmas = (
- self.model.ddim_sigmas_for_original_num_steps
- if use_original_steps
- else self.ddim_sigmas
- )
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full(
- (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
- )
-
- # current prediction for x_0
- if self.model.parameterization != "v":
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- else:
- pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
-
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
-
- if dynamic_threshold is not None:
- raise NotImplementedError()
-
- # direction pointing to x_t
- dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.0:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
- @torch.no_grad()
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
- # fast, but does not allow for exact reconstruction
- # t serves as an index to gather the correct alphas
- if use_original_steps:
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
- else:
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
-
- if noise is None:
- noise = torch.randn_like(x0)
- return (
- extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
- + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
- )
-
- @torch.no_grad()
- def decode(
- self,
- x_latent,
- cond,
- t_start,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- use_original_steps=False,
- **kwargs,
- ):
- timesteps = (
- np.arange(self.ddpm_num_timesteps)
- if use_original_steps
- else self.ddim_timesteps
- )
- timesteps = timesteps[:t_start]
-
- time_range = np.flip(timesteps)
- total_steps = timesteps.shape[0]
-
- iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
- x_dec = x_latent
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full(
- (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
- )
- x_dec, _ = self.p_sample_ddim(
- x_dec,
- cond,
- ts,
- index=index,
- use_original_steps=use_original_steps,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- **kwargs,
- )
- return x_dec
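
As a reference for the sampler removed above, a hedged usage sketch of DDIMSampler.sample(). The model object and the cond / uncond dictionaries (with the keys listed in the ddim_sampling docstring) are assumed to be prepared elsewhere; the numbers mirror that docstring and are illustrative only:

sampler = DDIMSampler(model)                   # model: a hypothetical wrapped diffusion model
samples, intermediates = sampler.sample(
    S=50,                                      # number of DDIM steps
    batch_size=5,
    shape=(4, 32, 32),                         # latent (C, H, W)
    conditioning=cond,                         # e.g. dict with context / camera / num_frames / ip / ip_img
    unconditional_guidance_scale=5.0,          # classifier-free guidance weight
    unconditional_conditioning=uncond,
    eta=0.0,                                   # eta = 0 -> deterministic DDIM update
)
# samples: final latents of shape (5, 4, 32, 32); intermediates: logged x_inter / pred_x0 snapshots
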
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 188b2eab520e6aaf281181afb6d84e1a3098a914..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 8b89c02f3c6827cdac3513cca77f2efd913115b0..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc
deleted file mode 100644
index e77dd1a9bb9eae984eb7a6f2f33a98a54662d64a..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc
deleted file mode 100644
index 378dad50ffdfa150bf000062d153d2dd6fb715b0..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc
deleted file mode 100644
index a6c836e51e276424666c8d22441a804bf8ae4722..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc
deleted file mode 100644
index ed990a3f549b22da782b4f7e7cfe07bf64d4969c..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/attention.py b/apps/third_party/CRM/imagedream/ldm/modules/attention.py
deleted file mode 100644
index 9578027d3d9e9b8766941dc0c986d42cd93b04bf..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/attention.py
+++ /dev/null
@@ -1,456 +0,0 @@
-from inspect import isfunction
-import math
-import torch
-import torch.nn.functional as F
-from torch import nn, einsum
-from einops import rearrange, repeat
-from typing import Optional, Any
-
-from .diffusionmodules.util import checkpoint
-
-
-try:
- import xformers
- import xformers.ops
-
- XFORMERS_IS_AVAILBLE = True
- except ImportError:
- XFORMERS_IS_AVAILBLE = False
-
-# CrossAttn precision handling
-import os
-
-_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
-
-
-def exists(val):
- return val is not None
-
-
-def uniq(arr):
- return {el: True for el in arr}.keys()
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def max_neg_value(t):
- return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
- dim = tensor.shape[-1]
- std = 1 / math.sqrt(dim)
- tensor.uniform_(-std, std)
- return tensor
-
-
-# feedforward
-class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x):
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = default(dim_out, dim)
- project_in = (
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
- if not glu
- else GEGLU(dim, inner_dim)
- )
-
- self.net = nn.Sequential(
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def Normalize(in_channels):
- return torch.nn.GroupNorm(
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
- )
-
-
-class SpatialSelfAttention(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = rearrange(q, "b c h w -> b (h w) c")
- k = rearrange(k, "b c h w -> b c (h w)")
- w_ = torch.einsum("bij,bjk->bik", q, k)
-
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = rearrange(v, "b c h w -> b c (h w)")
- w_ = rearrange(w_, "b i j -> b j i")
- h_ = torch.einsum("bij,bjk->bik", v, w_)
- h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-class MemoryEfficientCrossAttention(nn.Module):
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
- super().__init__()
- print(
- f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
- f"{heads} heads."
- )
- inner_dim = dim_head * heads
- context_dim = default(context_dim, query_dim)
-
- self.heads = heads
- self.dim_head = dim_head
-
- self.with_ip = kwargs.get("with_ip", False)
- if self.with_ip and (context_dim is not None):
- self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
- self.ip_dim = kwargs.get("ip_dim", 16)
- self.ip_weight = kwargs.get("ip_weight", 1.0)
-
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
-
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
- )
- self.attention_op = None
-
- def forward(self, x, context=None, mask=None):
- q = self.to_q(x)
-
- has_ip = self.with_ip and (context is not None)
- if has_ip:
- # context dim [(b frame_num), (77 + img_token), 1024]
- token_len = context.shape[1]
- context_ip = context[:, -self.ip_dim:, :]
- k_ip = self.to_k_ip(context_ip)
- v_ip = self.to_v_ip(context_ip)
- context = context[:, :(token_len - self.ip_dim), :]
-
- context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
-
- b, _, _ = q.shape
- q, k, v = map(
- lambda t: t.unsqueeze(3)
- .reshape(b, t.shape[1], self.heads, self.dim_head)
- .permute(0, 2, 1, 3)
- .reshape(b * self.heads, t.shape[1], self.dim_head)
- .contiguous(),
- (q, k, v),
- )
-
- # actually compute the attention, what we cannot get enough of
- out = xformers.ops.memory_efficient_attention(
- q, k, v, attn_bias=None, op=self.attention_op
- )
-
- if has_ip:
- k_ip, v_ip = map(
- lambda t: t.unsqueeze(3)
- .reshape(b, t.shape[1], self.heads, self.dim_head)
- .permute(0, 2, 1, 3)
- .reshape(b * self.heads, t.shape[1], self.dim_head)
- .contiguous(),
- (k_ip, v_ip),
- )
- # actually compute the attention, what we cannot get enough of
- out_ip = xformers.ops.memory_efficient_attention(
- q, k_ip, v_ip, attn_bias=None, op=self.attention_op
- )
- out = out + self.ip_weight * out_ip
-
- if exists(mask):
- raise NotImplementedError
- out = (
- out.unsqueeze(0)
- .reshape(b, self.heads, out.shape[1], self.dim_head)
- .permute(0, 2, 1, 3)
- .reshape(b, out.shape[1], self.heads * self.dim_head)
- )
- return self.to_out(out)
-
-
-class BasicTransformerBlock(nn.Module):
- def __init__(
- self,
- dim,
- n_heads,
- d_head,
- dropout=0.0,
- context_dim=None,
- gated_ff=True,
- checkpoint=True,
- disable_self_attn=False,
- **kwargs
- ):
- super().__init__()
- assert XFORMERS_IS_AVAILBLE, "xformers is not available"
- attn_cls = MemoryEfficientCrossAttention
- self.disable_self_attn = disable_self_attn
- self.attn1 = attn_cls(
- query_dim=dim,
- heads=n_heads,
- dim_head=d_head,
- dropout=dropout,
- context_dim=context_dim if self.disable_self_attn else None,
- ) # is a self-attention if not self.disable_self_attn
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
- self.attn2 = attn_cls(
- query_dim=dim,
- context_dim=context_dim,
- heads=n_heads,
- dim_head=d_head,
- dropout=dropout,
- **kwargs
- ) # is self-attn if context is none
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
- self.norm3 = nn.LayerNorm(dim)
- self.checkpoint = checkpoint
-
- def forward(self, x, context=None):
- return checkpoint(
- self._forward, (x, context), self.parameters(), self.checkpoint
- )
-
- def _forward(self, x, context=None):
- x = (
- self.attn1(
- self.norm1(x), context=context if self.disable_self_attn else None
- )
- + x
- )
- x = self.attn2(self.norm2(x), context=context) + x
- x = self.ff(self.norm3(x)) + x
- return x
-
-
-class SpatialTransformer(nn.Module):
- """
- Transformer block for image-like data.
- First, project the input (aka embedding)
- and reshape to b, t, d.
- Then apply standard transformer blocks.
- Finally, reshape back to an image.
- NEW: use_linear uses linear projections instead of the 1x1 convs for efficiency.
- """
-
- def __init__(
- self,
- in_channels,
- n_heads,
- d_head,
- depth=1,
- dropout=0.0,
- context_dim=None,
- disable_self_attn=False,
- use_linear=False,
- use_checkpoint=True,
- **kwargs
- ):
- super().__init__()
- if exists(context_dim) and not isinstance(context_dim, list):
- context_dim = [context_dim]
- self.in_channels = in_channels
- inner_dim = n_heads * d_head
- self.norm = Normalize(in_channels)
- if not use_linear:
- self.proj_in = nn.Conv2d(
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
- )
- else:
- self.proj_in = nn.Linear(in_channels, inner_dim)
-
- self.transformer_blocks = nn.ModuleList(
- [
- BasicTransformerBlock(
- inner_dim,
- n_heads,
- d_head,
- dropout=dropout,
- context_dim=context_dim[d],
- disable_self_attn=disable_self_attn,
- checkpoint=use_checkpoint,
- **kwargs
- )
- for d in range(depth)
- ]
- )
- if not use_linear:
- self.proj_out = zero_module(
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
- )
- else:
- self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
- self.use_linear = use_linear
-
- def forward(self, x, context=None):
- # note: if no context is given, cross-attention defaults to self-attention
- if not isinstance(context, list):
- context = [context]
- b, c, h, w = x.shape
- x_in = x
- x = self.norm(x)
- if not self.use_linear:
- x = self.proj_in(x)
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
- if self.use_linear:
- x = self.proj_in(x)
- for i, block in enumerate(self.transformer_blocks):
- x = block(x, context=context[i])
- if self.use_linear:
- x = self.proj_out(x)
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
- if not self.use_linear:
- x = self.proj_out(x)
- return x + x_in
-
-
-class BasicTransformerBlock3D(BasicTransformerBlock):
- def forward(self, x, context=None, num_frames=1):
- return checkpoint(
- self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
- )
-
- def _forward(self, x, context=None, num_frames=1):
- x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
- x = (
- self.attn1(
- self.norm1(x),
- context=context if self.disable_self_attn else None
- )
- + x
- )
- x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
- x = self.attn2(self.norm2(x), context=context) + x
- x = self.ff(self.norm3(x)) + x
- return x
-
-
-class SpatialTransformer3D(nn.Module):
- """3D self-attention"""
-
- def __init__(
- self,
- in_channels,
- n_heads,
- d_head,
- depth=1,
- dropout=0.0,
- context_dim=None,
- disable_self_attn=False,
- use_linear=False,
- use_checkpoint=True,
- **kwargs
- ):
- super().__init__()
- if exists(context_dim) and not isinstance(context_dim, list):
- context_dim = [context_dim]
- self.in_channels = in_channels
- inner_dim = n_heads * d_head
- self.norm = Normalize(in_channels)
- if not use_linear:
- self.proj_in = nn.Conv2d(
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
- )
- else:
- self.proj_in = nn.Linear(in_channels, inner_dim)
-
- self.transformer_blocks = nn.ModuleList(
- [
- BasicTransformerBlock3D(
- inner_dim,
- n_heads,
- d_head,
- dropout=dropout,
- context_dim=context_dim[d],
- disable_self_attn=disable_self_attn,
- checkpoint=use_checkpoint,
- **kwargs
- )
- for d in range(depth)
- ]
- )
- if not use_linear:
- self.proj_out = zero_module(
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
- )
- else:
- self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
- self.use_linear = use_linear
-
- def forward(self, x, context=None, num_frames=1):
- # note: if no context is given, cross-attention defaults to self-attention
- if not isinstance(context, list):
- context = [context]
- b, c, h, w = x.shape
- x_in = x
- x = self.norm(x)
- if not self.use_linear:
- x = self.proj_in(x)
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
- if self.use_linear:
- x = self.proj_in(x)
- for i, block in enumerate(self.transformer_blocks):
- x = block(x, context=context[i], num_frames=num_frames)
- if self.use_linear:
- x = self.proj_out(x)
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
- if not self.use_linear:
- x = self.proj_out(x)
- return x + x_in
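
For clarity, a small sketch of the image-prompt token split performed inside MemoryEfficientCrossAttention above. Shapes follow the in-code comment; ip_dim=16 is the default used there, and the tensors here are placeholders:

import torch

B, text_len, ip_dim, C = 2, 77, 16, 1024
context = torch.randn(B, text_len + ip_dim, C)    # [(b frame_num), 77 + img_token, 1024]

context_ip = context[:, -ip_dim:, :]              # trailing image-prompt tokens -> to_k_ip / to_v_ip
context_txt = context[:, :text_len, :]            # leading text tokens -> to_k / to_v

# the two attention results are then blended as: out = out_text + ip_weight * out_ip
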
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 87936969c3da946fe6a3d832741e28e3f8c5a465..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 709d11f26d4279dc79749621e1fadb3d3214664c..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc
deleted file mode 100644
index a776d040c73ba392b8c4498ceae40daa0ac5c375..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc
deleted file mode 100644
index 009edfe033ed8828bdc0ae3ae92f56076fe6752a..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc
deleted file mode 100644
index e42f7d66fb4493c5d29a2e00d78db7ee9459f090..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
deleted file mode 100644
index ef2270fc4896075b92413258975349f6c68c9bec..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc
deleted file mode 100644
index a1a0a6de705d51589cf78005348fa64f9e7eb29e..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
deleted file mode 100644
index f5e8589aeb17fb3c779c10c01f058ddaf3165198..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc
deleted file mode 100644
index 8aa9c29edf8fd31ddfd84c3b6784e57a6037b2fb..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
deleted file mode 100644
index 0a455c6a98a6d27b74477615cf6bf53f0899da51..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py
deleted file mode 100644
index 8d66e480728073294015cf0eb906dba471d602ca..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-import math
-
-import torch
-import torch.nn as nn
-
-
-# FFN
-def FeedForward(dim, mult=4):
- inner_dim = int(dim * mult)
- return nn.Sequential(
- nn.LayerNorm(dim),
- nn.Linear(dim, inner_dim, bias=False),
- nn.GELU(),
- nn.Linear(inner_dim, dim, bias=False),
- )
-
-
-def reshape_tensor(x, heads):
- bs, length, width = x.shape
- #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
- x = x.view(bs, length, heads, -1)
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
- x = x.transpose(1, 2)
- # reshape keeps (bs, n_heads, length, dim_per_head); it just makes the layout contiguous
- x = x.reshape(bs, heads, length, -1)
- return x
-
-
-class PerceiverAttention(nn.Module):
- def __init__(self, *, dim, dim_head=64, heads=8):
- super().__init__()
- self.scale = dim_head**-0.5
- self.dim_head = dim_head
- self.heads = heads
- inner_dim = dim_head * heads
-
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
-
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-
- def forward(self, x, latents):
- """
- Args:
- x (torch.Tensor): image features
- shape (b, n1, D)
- latent (torch.Tensor): latent features
- shape (b, n2, D)
- """
- x = self.norm1(x)
- latents = self.norm2(latents)
-
- b, l, _ = latents.shape
-
- q = self.to_q(latents)
- kv_input = torch.cat((x, latents), dim=-2)
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-
- q = reshape_tensor(q, self.heads)
- k = reshape_tensor(k, self.heads)
- v = reshape_tensor(v, self.heads)
-
- # attention
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
- out = weight @ v
-
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
-
- return self.to_out(out)
-
-
-class ImageProjModel(torch.nn.Module):
- """Projection Model"""
- def __init__(self,
- cross_attention_dim=1024,
- clip_embeddings_dim=1024,
- clip_extra_context_tokens=4):
- super().__init__()
- self.cross_attention_dim = cross_attention_dim
- self.clip_extra_context_tokens = clip_extra_context_tokens
-
- # from 1024 -> 4 * 1024
- self.proj = torch.nn.Linear(
- clip_embeddings_dim,
- self.clip_extra_context_tokens * cross_attention_dim)
- self.norm = torch.nn.LayerNorm(cross_attention_dim)
-
- def forward(self, image_embeds):
- embeds = image_embeds
- clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
- clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
- return clip_extra_context_tokens
-
-
-class SimpleReSampler(nn.Module):
- def __init__(self, embedding_dim=1280, output_dim=1024):
- super().__init__()
- self.proj_out = nn.Linear(embedding_dim, output_dim)
- self.norm_out = nn.LayerNorm(output_dim)
-
- def forward(self, latents):
- """
- latents: B 256 N
- """
- latents = self.proj_out(latents)
- return self.norm_out(latents)
-
-
-class Resampler(nn.Module):
- def __init__(
- self,
- dim=1024,
- depth=8,
- dim_head=64,
- heads=16,
- num_queries=8,
- embedding_dim=768,
- output_dim=1024,
- ff_mult=4,
- ):
- super().__init__()
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
- self.proj_in = nn.Linear(embedding_dim, dim)
- self.proj_out = nn.Linear(dim, output_dim)
- self.norm_out = nn.LayerNorm(output_dim)
-
- self.layers = nn.ModuleList([])
- for _ in range(depth):
- self.layers.append(
- nn.ModuleList(
- [
- PerceiverAttention(dim=dim,
- dim_head=dim_head,
- heads=heads),
- FeedForward(dim=dim, mult=ff_mult),
- ]
- )
- )
-
- def forward(self, x):
- latents = self.latents.repeat(x.size(0), 1, 1)
- x = self.proj_in(x)
- for attn, ff in self.layers:
- latents = attn(x, latents) + latents
- latents = ff(latents) + latents
-
- latents = self.proj_out(latents)
- return self.norm_out(latents)
-
-
-if __name__ == '__main__':
- resampler = Resampler(embedding_dim=1280)
- resampler = SimpleReSampler(embedding_dim=1280)
- tensor = torch.rand(4, 257, 1280)
- embed = resampler(tensor)
- # embed = (tensor)
- print(embed.shape)
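
One detail worth noting in PerceiverAttention above: scaling q and k each by dim_head ** -0.25 before the matmul is mathematically the same as the usual q k^T / sqrt(dim_head), but keeps the intermediate products smaller, which is friendlier to fp16. A quick self-contained check (illustrative only):

import math
import torch

d = 64                                            # dim_head
q = torch.randn(1, 8, 16, d)                      # (bs, heads, queries, dim_per_head)
k = torch.randn(1, 8, 16, d)

scale = 1 / math.sqrt(math.sqrt(d))               # d ** -0.25, as in PerceiverAttention
a = (q * scale) @ (k * scale).transpose(-2, -1)   # scale both operands first
b = (q @ k.transpose(-2, -1)) / math.sqrt(d)      # textbook scaled dot product

print(torch.allclose(a, b, atol=1e-5))            # True up to float rounding
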
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py
deleted file mode 100644
index 52a1d5c8e8ba62dd25133ffe76d370c637c5d25e..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py
+++ /dev/null
@@ -1,1018 +0,0 @@
-# pytorch_diffusion + derived encoder decoder
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import rearrange
-from typing import Optional, Any
-
-from ..attention import MemoryEfficientCrossAttention
-
-try:
- import xformers
- import xformers.ops
-
- XFORMERS_IS_AVAILBLE = True
-except ImportError:
- XFORMERS_IS_AVAILBLE = False
- print("No module 'xformers'. Proceeding without it.")
-
-
-def get_timestep_embedding(timesteps, embedding_dim):
- """
- Build sinusoidal timestep embeddings.
- This matches the implementation in Denoising Diffusion Probabilistic Models
- (and the tensor2tensor / Fairseq implementation), but differs slightly
- from the description in Section 3.5 of "Attention Is All You Need".
- Expects timesteps as a 1-D tensor of indices.
- """
- assert len(timesteps.shape) == 1
-
- half_dim = embedding_dim // 2
- emb = math.log(10000) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
- emb = emb.to(device=timesteps.device)
- emb = timesteps.float()[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
- return emb
-
-
-def nonlinearity(x):
- # swish
- return x * torch.sigmoid(x)
-
-
-def Normalize(in_channels, num_groups=32):
- return torch.nn.GroupNorm(
- num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
- )
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
- if self.with_conv:
- x = self.conv(x)
- return x
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=2, padding=0
- )
-
- def forward(self, x):
- if self.with_conv:
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- else:
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
- return x
-
-
-class ResnetBlock(nn.Module):
- def __init__(
- self,
- *,
- in_channels,
- out_channels=None,
- conv_shortcut=False,
- dropout,
- temb_channels=512,
- ):
- super().__init__()
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.use_conv_shortcut = conv_shortcut
-
- self.norm1 = Normalize(in_channels)
- self.conv1 = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- if temb_channels > 0:
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
- self.norm2 = Normalize(out_channels)
- self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = torch.nn.Conv2d(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- else:
- self.nin_shortcut = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x, temb):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
-
- if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
-
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
- else:
- x = self.nin_shortcut(x)
-
- return x + h
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h * w)
- q = q.permute(0, 2, 1) # b,hw,c
- k = k.reshape(b, c, h * w) # b,c,hw
- w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h * w)
- w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b, c, h, w)
-
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-class MemoryEfficientAttnBlock(nn.Module):
- """
- Uses xformers efficient implementation,
- see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
- Note: this is a single-head self-attention operation
- """
-
- #
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.attention_op = None
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- B, C, H, W = q.shape
- q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
-
- q, k, v = map(
- lambda t: t.unsqueeze(3)
- .reshape(B, t.shape[1], 1, C)
- .permute(0, 2, 1, 3)
- .reshape(B * 1, t.shape[1], C)
- .contiguous(),
- (q, k, v),
- )
- out = xformers.ops.memory_efficient_attention(
- q, k, v, attn_bias=None, op=self.attention_op
- )
-
- out = (
- out.unsqueeze(0)
- .reshape(B, 1, out.shape[1], C)
- .permute(0, 2, 1, 3)
- .reshape(B, out.shape[1], C)
- )
- out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
- out = self.proj_out(out)
- return x + out
-
-
-class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
- def forward(self, x, context=None, mask=None):
- b, c, h, w = x.shape
- x_seq = rearrange(x, "b c h w -> b (h w) c")
- out = super().forward(x_seq, context=context, mask=mask)
- out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
- return x + out
-
-
-def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
- assert attn_type in [
- "vanilla",
- "vanilla-xformers",
- "memory-efficient-cross-attn",
- "linear",
- "none",
- ], f"attn_type {attn_type} unknown"
- if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
- attn_type = "vanilla-xformers"
- print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
- if attn_type == "vanilla":
- assert attn_kwargs is None
- return AttnBlock(in_channels)
- elif attn_type == "vanilla-xformers":
- print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
- return MemoryEfficientAttnBlock(in_channels)
- elif attn_type == "memory-efficient-cross-attn":
- attn_kwargs["query_dim"] = in_channels
- return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
- elif attn_type == "none":
- return nn.Identity(in_channels)
- else:
- raise NotImplementedError()
-
-
-class Model(nn.Module):
- def __init__(
- self,
- *,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- in_channels,
- resolution,
- use_timestep=True,
- use_linear_attn=False,
- attn_type="vanilla",
- ):
- super().__init__()
- if use_linear_attn:
- attn_type = "linear"
- self.ch = ch
- self.temb_ch = self.ch * 4
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
-
- self.use_timestep = use_timestep
- if self.use_timestep:
- # timestep embedding
- self.temb = nn.Module()
- self.temb.dense = nn.ModuleList(
- [
- torch.nn.Linear(self.ch, self.temb_ch),
- torch.nn.Linear(self.temb_ch, self.temb_ch),
- ]
- )
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
- )
-
- curr_res = resolution
- in_ch_mult = (1,) + tuple(ch_mult)
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch * in_ch_mult[i_level]
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- down = nn.Module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions - 1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch * ch_mult[i_level]
- skip_in = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- if i_block == self.num_res_blocks:
- skip_in = ch * in_ch_mult[i_level]
- block.append(
- ResnetBlock(
- in_channels=block_in + skip_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- up = nn.Module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_ch, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x, t=None, context=None):
- # assert x.shape[2] == x.shape[3] == self.resolution
- if context is not None:
- # assume aligned context, cat along channel axis
- x = torch.cat((x, context), dim=1)
- if self.use_timestep:
- # timestep embedding
- assert t is not None
- temb = get_timestep_embedding(t, self.ch)
- temb = self.temb.dense[0](temb)
- temb = nonlinearity(temb)
- temb = self.temb.dense[1](temb)
- else:
- temb = None
-
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions - 1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.up[i_level].block[i_block](
- torch.cat([h, hs.pop()], dim=1), temb
- )
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
- def get_last_layer(self):
- return self.conv_out.weight
-
-
-class Encoder(nn.Module):
- def __init__(
- self,
- *,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- in_channels,
- resolution,
- z_channels,
- double_z=True,
- use_linear_attn=False,
- attn_type="vanilla",
- **ignore_kwargs,
- ):
- super().__init__()
- if use_linear_attn:
- attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
- )
-
- curr_res = resolution
- in_ch_mult = (1,) + tuple(ch_mult)
- self.in_ch_mult = in_ch_mult
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch * in_ch_mult[i_level]
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- down = nn.Module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions - 1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in,
- 2 * z_channels if double_z else z_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- )
-
- def forward(self, x):
- # timestep embedding
- temb = None
-
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions - 1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class Decoder(nn.Module):
- def __init__(
- self,
- *,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- in_channels,
- resolution,
- z_channels,
- give_pre_end=False,
- tanh_out=False,
- use_linear_attn=False,
- attn_type="vanilla",
- **ignorekwargs,
- ):
- super().__init__()
- if use_linear_attn:
- attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.give_pre_end = give_pre_end
- self.tanh_out = tanh_out
-
- # compute in_ch_mult, block_in and curr_res at lowest res
- in_ch_mult = (1,) + tuple(ch_mult)
- block_in = ch * ch_mult[self.num_resolutions - 1]
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.z_shape = (1, z_channels, curr_res, curr_res)
- print(
- "Working with z of shape {} = {} dimensions.".format(
- self.z_shape, np.prod(self.z_shape)
- )
- )
-
- # z to block_in
- self.conv_in = torch.nn.Conv2d(
- z_channels, block_in, kernel_size=3, stride=1, padding=1
- )
-
- # middle
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- up = nn.Module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_ch, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, z):
- # assert z.shape[1:] == self.z_shape[1:]
- self.last_z_shape = z.shape
-
- # timestep embedding
- temb = None
-
- # z to block_in
- h = self.conv_in(z)
-
- # middle
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.up[i_level].block[i_block](h, temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- if self.give_pre_end:
- return h
-
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
-
-
-class SimpleDecoder(nn.Module):
- def __init__(self, in_channels, out_channels, *args, **kwargs):
- super().__init__()
- self.model = nn.ModuleList(
- [
- nn.Conv2d(in_channels, in_channels, 1),
- ResnetBlock(
- in_channels=in_channels,
- out_channels=2 * in_channels,
- temb_channels=0,
- dropout=0.0,
- ),
- ResnetBlock(
- in_channels=2 * in_channels,
- out_channels=4 * in_channels,
- temb_channels=0,
- dropout=0.0,
- ),
- ResnetBlock(
- in_channels=4 * in_channels,
- out_channels=2 * in_channels,
- temb_channels=0,
- dropout=0.0,
- ),
- nn.Conv2d(2 * in_channels, in_channels, 1),
- Upsample(in_channels, with_conv=True),
- ]
- )
- # end
- self.norm_out = Normalize(in_channels)
- self.conv_out = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- for i, layer in enumerate(self.model):
- if i in [1, 2, 3]:
- x = layer(x, None)
- else:
- x = layer(x)
-
- h = self.norm_out(x)
- h = nonlinearity(h)
- x = self.conv_out(h)
- return x
-
-
-class UpsampleDecoder(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- ch,
- num_res_blocks,
- resolution,
- ch_mult=(2, 2),
- dropout=0.0,
- ):
- super().__init__()
- # upsampling
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- block_in = in_channels
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.res_blocks = nn.ModuleList()
- self.upsample_blocks = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- res_block = []
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- res_block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- self.res_blocks.append(nn.ModuleList(res_block))
- if i_level != self.num_resolutions - 1:
- self.upsample_blocks.append(Upsample(block_in, True))
- curr_res = curr_res * 2
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- # upsampling
- h = x
- for k, i_level in enumerate(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.res_blocks[i_level][i_block](h, None)
- if i_level != self.num_resolutions - 1:
- h = self.upsample_blocks[k](h)
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class LatentRescaler(nn.Module):
- def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
- super().__init__()
- # residual block, interpolate, residual block
- self.factor = factor
- self.conv_in = nn.Conv2d(
- in_channels, mid_channels, kernel_size=3, stride=1, padding=1
- )
- self.res_block1 = nn.ModuleList(
- [
- ResnetBlock(
- in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0,
- )
- for _ in range(depth)
- ]
- )
- self.attn = AttnBlock(mid_channels)
- self.res_block2 = nn.ModuleList(
- [
- ResnetBlock(
- in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0,
- )
- for _ in range(depth)
- ]
- )
-
- self.conv_out = nn.Conv2d(
- mid_channels,
- out_channels,
- kernel_size=1,
- )
-
- def forward(self, x):
- x = self.conv_in(x)
- for block in self.res_block1:
- x = block(x, None)
- x = torch.nn.functional.interpolate(
- x,
- size=(
- int(round(x.shape[2] * self.factor)),
- int(round(x.shape[3] * self.factor)),
- ),
- )
- x = self.attn(x)
- for block in self.res_block2:
- x = block(x, None)
- x = self.conv_out(x)
- return x
-
-
-class MergedRescaleEncoder(nn.Module):
- def __init__(
- self,
- in_channels,
- ch,
- resolution,
- out_ch,
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- ch_mult=(1, 2, 4, 8),
- rescale_factor=1.0,
- rescale_module_depth=1,
- ):
- super().__init__()
- intermediate_chn = ch * ch_mult[-1]
- self.encoder = Encoder(
- in_channels=in_channels,
- num_res_blocks=num_res_blocks,
- ch=ch,
- ch_mult=ch_mult,
- z_channels=intermediate_chn,
- double_z=False,
- resolution=resolution,
- attn_resolutions=attn_resolutions,
- dropout=dropout,
- resamp_with_conv=resamp_with_conv,
- out_ch=None,
- )
- self.rescaler = LatentRescaler(
- factor=rescale_factor,
- in_channels=intermediate_chn,
- mid_channels=intermediate_chn,
- out_channels=out_ch,
- depth=rescale_module_depth,
- )
-
- def forward(self, x):
- x = self.encoder(x)
- x = self.rescaler(x)
- return x
-
-
-class MergedRescaleDecoder(nn.Module):
- def __init__(
- self,
- z_channels,
- out_ch,
- resolution,
- num_res_blocks,
- attn_resolutions,
- ch,
- ch_mult=(1, 2, 4, 8),
- dropout=0.0,
- resamp_with_conv=True,
- rescale_factor=1.0,
- rescale_module_depth=1,
- ):
- super().__init__()
- tmp_chn = z_channels * ch_mult[-1]
- self.decoder = Decoder(
- out_ch=out_ch,
- z_channels=tmp_chn,
- attn_resolutions=attn_resolutions,
- dropout=dropout,
- resamp_with_conv=resamp_with_conv,
- in_channels=None,
- num_res_blocks=num_res_blocks,
- ch_mult=ch_mult,
- resolution=resolution,
- ch=ch,
- )
- self.rescaler = LatentRescaler(
- factor=rescale_factor,
- in_channels=z_channels,
- mid_channels=tmp_chn,
- out_channels=tmp_chn,
- depth=rescale_module_depth,
- )
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Upsampler(nn.Module):
- def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
- super().__init__()
- assert out_size >= in_size
- num_blocks = int(np.log2(out_size // in_size)) + 1
- factor_up = 1.0 + (out_size % in_size)
- print(
- f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
- )
- self.rescaler = LatentRescaler(
- factor=factor_up,
- in_channels=in_channels,
- mid_channels=2 * in_channels,
- out_channels=in_channels,
- )
- self.decoder = Decoder(
- out_ch=out_channels,
- resolution=out_size,
- z_channels=in_channels,
- num_res_blocks=2,
- attn_resolutions=[],
- in_channels=None,
- ch=in_channels,
- ch_mult=[ch_mult for _ in range(num_blocks)],
- )
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Resize(nn.Module):
- def __init__(self, in_channels=None, learned=False, mode="bilinear"):
- super().__init__()
- self.with_conv = learned
- self.mode = mode
- if self.with_conv:
- print(
- f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
- )
- raise NotImplementedError()
- assert in_channels is not None
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=4, stride=2, padding=1
- )
-
- def forward(self, x, scale_factor=1.0):
- if scale_factor == 1.0:
- return x
- else:
- x = torch.nn.functional.interpolate(
- x, mode=self.mode, align_corners=False, scale_factor=scale_factor
- )
- return x
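Taken together, the Encoder/Decoder pair above forms the usual KL autoencoder backbone. As rough orientation, here is a minimal round-trip sketch; the hyperparameters are assumptions chosen in the spirit of the common KL-f8 configuration, not values read from this repository:

import torch

# Assumed KL-f8-style settings; Encoder/Decoder are the classes defined above.
enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=256,
              z_channels=4, double_z=True)
dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=256,
              z_channels=4)

x = torch.randn(1, 3, 256, 256)
moments = enc(x)        # (1, 8, 32, 32): mean and logvar stacked because double_z=True
z = moments[:, :4]      # deterministic latent: keep the mean half
x_rec = dec(z)          # (1, 3, 256, 256)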
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py
deleted file mode 100644
index 2f12a389584a729e1177af8683d631e5c5d77fd5..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py
+++ /dev/null
@@ -1,1135 +0,0 @@
-from abc import abstractmethod
-import math
-
-import numpy as np
-import torch
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange, repeat
-
-from imagedream.ldm.modules.diffusionmodules.util import (
- checkpoint,
- conv_nd,
- linear,
- avg_pool_nd,
- zero_module,
- normalization,
- timestep_embedding,
- convert_module_to_f16,
- convert_module_to_f32
-)
-from imagedream.ldm.modules.attention import (
- SpatialTransformer,
- SpatialTransformer3D,
- exists
-)
-from imagedream.ldm.modules.diffusionmodules.adaptors import (
- Resampler,
- ImageProjModel
-)
-
-## go
-class AttentionPool2d(nn.Module):
- """
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
- """
-
- def __init__(
- self,
- spacial_dim: int,
- embed_dim: int,
- num_heads_channels: int,
- output_dim: int = None,
- ):
- super().__init__()
- self.positional_embedding = nn.Parameter(
- th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
- )
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
- self.num_heads = embed_dim // num_heads_channels
- self.attention = QKVAttention(self.num_heads)
-
- def forward(self, x):
- b, c, *_spatial = x.shape
- x = x.reshape(b, c, -1) # NC(HW)
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
- x = self.qkv_proj(x)
- x = self.attention(x)
- x = self.c_proj(x)
- return x[:, :, 0]
-
-
-class TimestepBlock(nn.Module):
- """
- Any module where forward() takes timestep embeddings as a second argument.
- """
-
- @abstractmethod
- def forward(self, x, emb):
- """
- Apply the module to `x` given `emb` timestep embeddings.
- """
-
-
-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
- """
- A sequential module that passes timestep embeddings to the children that
- support it as an extra input.
- """
-
- def forward(self, x, emb, context=None, num_frames=1):
- for layer in self:
- if isinstance(layer, TimestepBlock):
- x = layer(x, emb)
- elif isinstance(layer, SpatialTransformer3D):
- x = layer(x, context, num_frames=num_frames)
- elif isinstance(layer, SpatialTransformer):
- x = layer(x, context)
- else:
- x = layer(x)
- return x
-
-
-class Upsample(nn.Module):
- """
- An upsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- upsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- if use_conv:
- self.conv = conv_nd(
- dims, self.channels, self.out_channels, 3, padding=padding
- )
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- if self.dims == 3:
- x = F.interpolate(
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
- )
- else:
- x = F.interpolate(x, scale_factor=2, mode="nearest")
- if self.use_conv:
- x = self.conv(x)
- return x
-
-
-class TransposedUpsample(nn.Module):
- "Learned 2x upsampling without padding"
-
- def __init__(self, channels, out_channels=None, ks=5):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
-
- self.up = nn.ConvTranspose2d(
- self.channels, self.out_channels, kernel_size=ks, stride=2
- )
-
- def forward(self, x):
- return self.up(x)
-
-
-class Downsample(nn.Module):
- """
- A downsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- downsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- stride = 2 if dims != 3 else (1, 2, 2)
- if use_conv:
- self.op = conv_nd(
- dims,
- self.channels,
- self.out_channels,
- 3,
- stride=stride,
- padding=padding,
- )
- else:
- assert self.channels == self.out_channels
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- return self.op(x)
-
-
-class ResBlock(TimestepBlock):
- """
- A residual block that can optionally change the number of channels.
- :param channels: the number of input channels.
- :param emb_channels: the number of timestep embedding channels.
- :param dropout: the rate of dropout.
- :param out_channels: if specified, the number of out channels.
- :param use_conv: if True and out_channels is specified, use a spatial
- convolution instead of a smaller 1x1 convolution to change the
- channels in the skip connection.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param use_checkpoint: if True, use gradient checkpointing on this module.
- :param up: if True, use this block for upsampling.
- :param down: if True, use this block for downsampling.
- """
-
- def __init__(
- self,
- channels,
- emb_channels,
- dropout,
- out_channels=None,
- use_conv=False,
- use_scale_shift_norm=False,
- dims=2,
- use_checkpoint=False,
- up=False,
- down=False,
- ):
- super().__init__()
- self.channels = channels
- self.emb_channels = emb_channels
- self.dropout = dropout
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.use_checkpoint = use_checkpoint
- self.use_scale_shift_norm = use_scale_shift_norm
-
- self.in_layers = nn.Sequential(
- normalization(channels),
- nn.SiLU(),
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
- )
-
- self.updown = up or down
-
- if up:
- self.h_upd = Upsample(channels, False, dims)
- self.x_upd = Upsample(channels, False, dims)
- elif down:
- self.h_upd = Downsample(channels, False, dims)
- self.x_upd = Downsample(channels, False, dims)
- else:
- self.h_upd = self.x_upd = nn.Identity()
-
- self.emb_layers = nn.Sequential(
- nn.SiLU(),
- linear(
- emb_channels,
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
- ),
- )
- self.out_layers = nn.Sequential(
- normalization(self.out_channels),
- nn.SiLU(),
- nn.Dropout(p=dropout),
- zero_module(
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
- ),
- )
-
- if self.out_channels == channels:
- self.skip_connection = nn.Identity()
- elif use_conv:
- self.skip_connection = conv_nd(
- dims, channels, self.out_channels, 3, padding=1
- )
- else:
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
- def forward(self, x, emb):
- """
- Apply the block to a Tensor, conditioned on a timestep embedding.
- :param x: an [N x C x ...] Tensor of features.
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
- :return: an [N x C x ...] Tensor of outputs.
- """
- return checkpoint(
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
- )
-
- def _forward(self, x, emb):
- if self.updown:
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
- h = in_rest(x)
- h = self.h_upd(h)
- x = self.x_upd(x)
- h = in_conv(h)
- else:
- h = self.in_layers(x)
- emb_out = self.emb_layers(emb).type(h.dtype)
- while len(emb_out.shape) < len(h.shape):
- emb_out = emb_out[..., None]
- if self.use_scale_shift_norm:
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
- scale, shift = th.chunk(emb_out, 2, dim=1)
- h = out_norm(h) * (1 + scale) + shift
- h = out_rest(h)
- else:
- h = h + emb_out
- h = self.out_layers(h)
- return self.skip_connection(x) + h
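As a quick illustration of the ResBlock contract (a hedged sketch; the channel sizes below are illustrative assumptions, not values from any specific config):

import torch

block = ResBlock(channels=320, emb_channels=1280, dropout=0.0, out_channels=640)
h = block(torch.randn(2, 320, 32, 32), torch.randn(2, 1280))
print(h.shape)   # torch.Size([2, 640, 32, 32]); the 1x1 skip conv matches the channel change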
-
-
-class AttentionBlock(nn.Module):
- """
- An attention block that allows spatial positions to attend to each other.
- Originally ported from here, but adapted to the N-d case.
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
- """
-
- def __init__(
- self,
- channels,
- num_heads=1,
- num_head_channels=-1,
- use_checkpoint=False,
- use_new_attention_order=False,
- ):
- super().__init__()
- self.channels = channels
- if num_head_channels == -1:
- self.num_heads = num_heads
- else:
- assert (
- channels % num_head_channels == 0
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
- self.num_heads = channels // num_head_channels
- self.use_checkpoint = use_checkpoint
- self.norm = normalization(channels)
- self.qkv = conv_nd(1, channels, channels * 3, 1)
- if use_new_attention_order:
- # split qkv before split heads
- self.attention = QKVAttention(self.num_heads)
- else:
- # split heads before split qkv
- self.attention = QKVAttentionLegacy(self.num_heads)
-
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
- def forward(self, x):
- return checkpoint(
- self._forward, (x,), self.parameters(), True
- ) # TODO: verify checkpoint usage (flag is hard-coded to True) and fix the .half call
- # return pt_checkpoint(self._forward, x) # pytorch
-
- def _forward(self, x):
- b, c, *spatial = x.shape
- x = x.reshape(b, c, -1)
- qkv = self.qkv(self.norm(x))
- h = self.attention(qkv)
- h = self.proj_out(h)
- return (x + h).reshape(b, c, *spatial)
-
-
-def count_flops_attn(model, _x, y):
- """
- A counter for the `thop` package to count the operations in an
- attention operation.
- Meant to be used like:
- macs, params = thop.profile(
- model,
- inputs=(inputs, timestamps),
- custom_ops={QKVAttention: QKVAttention.count_flops},
- )
- """
- b, c, *spatial = y[0].shape
- num_spatial = int(np.prod(spatial))
- # We perform two matmuls with the same number of ops.
- # The first computes the weight matrix, the second computes
- # the combination of the value vectors.
- matmul_ops = 2 * b * (num_spatial**2) * c
- model.total_ops += th.DoubleTensor([matmul_ops])
-
-
-class QKVAttentionLegacy(nn.Module):
- """
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts", q * scale, k * scale
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum("bts,bcs->bct", weight, v)
- return a.reshape(bs, -1, length)
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
-
-
-class QKVAttention(nn.Module):
- """
- A module which performs QKV attention and splits in a different order.
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = qkv.chunk(3, dim=1)
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts",
- (q * scale).view(bs * self.n_heads, ch, length),
- (k * scale).view(bs * self.n_heads, ch, length),
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
- return a.reshape(bs, -1, length)
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
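The two attention modules differ only in how the fused qkv tensor is laid out (heads-then-qkv vs. qkv-then-heads); the output shape is the same. A hedged shape-check sketch with made-up sizes:

import torch

n_heads, ch, length = 4, 32, 64
qkv = torch.randn(2, 3 * n_heads * ch, length)    # [N, 3*H*C, T]
out = QKVAttention(n_heads)(qkv)                  # -> [N, H*C, T] == (2, 128, 64)
out_legacy = QKVAttentionLegacy(n_heads)(qkv)     # same shape; interprets qkv as [N, H*3*C, T]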
-
-
-class Timestep(nn.Module):
- def __init__(self, dim):
- super().__init__()
- self.dim = dim
-
- def forward(self, t):
- return timestep_embedding(t, self.dim)
-
-
-class MultiViewUNetModel(nn.Module):
- """
- The full multi-view UNet model with attention, timestep embedding and camera embedding.
- :param in_channels: channels in the input Tensor.
- :param model_channels: base channel count for the model.
- :param out_channels: channels in the output Tensor.
- :param num_res_blocks: number of residual blocks per downsample.
- :param attention_resolutions: a collection of downsample rates at which
- attention will take place. May be a set, list, or tuple.
- For example, if this contains 4, then at 4x downsampling, attention
- will be used.
- :param dropout: the dropout probability.
- :param channel_mult: channel multiplier for each level of the UNet.
- :param conv_resample: if True, use learned convolutions for upsampling and
- downsampling.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param num_classes: if specified (as an int), then this model will be
- class-conditional with `num_classes` classes.
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
- :param num_heads: the number of attention heads in each attention layer.
- :param num_head_channels: if specified, ignore num_heads and instead use
- a fixed channel width per attention head.
- :param num_heads_upsample: works with num_heads to set a different number
- of heads for upsampling. Deprecated.
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
- :param resblock_updown: use residual blocks for up/downsampling.
- :param use_new_attention_order: use a different attention pattern for potentially
- increased efficiency.
- :param camera_dim: dimensionality of camera input.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- use_checkpoint=False,
- use_fp16=False,
- use_bf16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
- disable_self_attentions=None,
- num_attention_blocks=None,
- disable_middle_self_attn=False,
- use_linear_in_transformer=False,
- adm_in_channels=None,
- camera_dim=None,
- with_ip=False, # whether to use image prompt conditioning
- ip_dim=0, # number of extra tokens: 4 for global, 16 for local
- ip_weight=1.0, # weight for the image prompt context
- ip_mode="local_resample", # adaptor mode: "global" or "local_resample"
- ):
- super().__init__()
- if use_spatial_transformer:
- assert (
- context_dim is not None
- ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
-
- if context_dim is not None:
- assert (
- use_spatial_transformer
- ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
- from omegaconf.listconfig import ListConfig
-
- if type(context_dim) == ListConfig:
- context_dim = list(context_dim)
-
- if num_heads_upsample == -1:
- num_heads_upsample = num_heads
-
- if num_heads == -1:
- assert (
- num_head_channels != -1
- ), "Either num_heads or num_head_channels has to be set"
-
- if num_head_channels == -1:
- assert (
- num_heads != -1
- ), "Either num_heads or num_head_channels has to be set"
-
- self.image_size = image_size
- self.in_channels = in_channels
- self.model_channels = model_channels
- self.out_channels = out_channels
- if isinstance(num_res_blocks, int):
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
- else:
- if len(num_res_blocks) != len(channel_mult):
- raise ValueError(
- "provide num_res_blocks either as an int (globally constant) or "
- "as a list/tuple (per-level) with the same length as channel_mult"
- )
- self.num_res_blocks = num_res_blocks
- if disable_self_attentions is not None:
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
- assert len(disable_self_attentions) == len(channel_mult)
- if num_attention_blocks is not None:
- assert len(num_attention_blocks) == len(self.num_res_blocks)
- assert all(
- map(
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
- range(len(num_attention_blocks)),
- )
- )
- print(
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
- f"attention will still not be set."
- )
-
- self.attention_resolutions = attention_resolutions
- self.dropout = dropout
- self.channel_mult = channel_mult
- self.conv_resample = conv_resample
- self.num_classes = num_classes
- self.use_checkpoint = use_checkpoint
- self.dtype = th.float16 if use_fp16 else th.float32
- self.dtype = th.bfloat16 if use_bf16 else self.dtype
- self.num_heads = num_heads
- self.num_head_channels = num_head_channels
- self.num_heads_upsample = num_heads_upsample
- self.predict_codebook_ids = n_embed is not None
-
- self.with_ip = with_ip # whether an image prompt is used
- self.ip_dim = ip_dim # number of extra tokens: 4 for global, 16 for local
- self.ip_weight = ip_weight
- assert ip_mode in ["global", "local_resample"]
- self.ip_mode = ip_mode # which mode of adaptor
-
- time_embed_dim = model_channels * 4
- self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- if camera_dim is not None:
- time_embed_dim = model_channels * 4
- self.camera_embed = nn.Sequential(
- linear(camera_dim, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- if self.num_classes is not None:
- if isinstance(self.num_classes, int):
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
- elif self.num_classes == "continuous":
- print("setting up linear c_adm embedding layer")
- self.label_emb = nn.Linear(1, time_embed_dim)
- elif self.num_classes == "sequential":
- assert adm_in_channels is not None
- self.label_emb = nn.Sequential(
- nn.Sequential(
- linear(adm_in_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
- )
- else:
- raise ValueError()
-
- if self.with_ip and (context_dim is not None) and ip_dim > 0:
- if self.ip_mode == "local_resample":
- # ip-adapter-plus
- hidden_dim = 1280
- self.image_embed = Resampler(
- dim=context_dim,
- depth=4,
- dim_head=64,
- heads=12,
- num_queries=ip_dim, # num token
- embedding_dim=hidden_dim,
- output_dim=context_dim,
- ff_mult=4,
- )
- elif self.ip_mode == "global":
- self.image_embed = ImageProjModel(
- cross_attention_dim=context_dim,
- clip_extra_context_tokens=ip_dim)
- else:
- raise ValueError(f"{self.ip_mode} is not supported")
-
- self.input_blocks = nn.ModuleList(
- [
- TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
- )
- ]
- )
- self._feature_size = model_channels
- input_block_chans = [model_channels]
- ch = model_channels
- ds = 1
- for level, mult in enumerate(channel_mult):
- for nr in range(self.num_res_blocks[level]):
- layers = [
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=mult * model_channels,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = mult * model_channels
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- # num_heads = 1
- dim_head = (
- ch // num_heads
- if use_spatial_transformer
- else num_head_channels
- )
- if exists(disable_self_attentions):
- disabled_sa = disable_self_attentions[level]
- else:
- disabled_sa = False
-
- if (
- not exists(num_attention_blocks)
- or nr < num_attention_blocks[level]
- ):
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- )
- if not use_spatial_transformer
- else SpatialTransformer3D(
- ch,
- num_heads,
- dim_head,
- depth=transformer_depth,
- context_dim=context_dim,
- disable_self_attn=disabled_sa,
- use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint,
- with_ip=self.with_ip,
- ip_dim=self.ip_dim,
- ip_weight=self.ip_weight
- )
- )
- self.input_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
- input_block_chans.append(ch)
-
- if level != len(channel_mult) - 1:
- out_ch = ch
- self.input_blocks.append(
- TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- down=True,
- )
- if resblock_updown
- else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
- )
- )
- )
- ch = out_ch
- input_block_chans.append(ch)
- ds *= 2
- self._feature_size += ch
-
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- # num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- )
- if not use_spatial_transformer
- else SpatialTransformer3D( # always uses a self-attn
- ch,
- num_heads,
- dim_head,
- depth=transformer_depth,
- context_dim=context_dim,
- disable_self_attn=disable_middle_self_attn,
- use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint,
- with_ip=self.with_ip,
- ip_dim=self.ip_dim,
- ip_weight=self.ip_weight
- ),
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
- self._feature_size += ch
-
- self.output_blocks = nn.ModuleList([])
- for level, mult in list(enumerate(channel_mult))[::-1]:
- for i in range(self.num_res_blocks[level] + 1):
- ich = input_block_chans.pop()
- layers = [
- ResBlock(
- ch + ich,
- time_embed_dim,
- dropout,
- out_channels=model_channels * mult,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = model_channels * mult
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- # num_heads = 1
- dim_head = (
- ch // num_heads
- if use_spatial_transformer
- else num_head_channels
- )
- if exists(disable_self_attentions):
- disabled_sa = disable_self_attentions[level]
- else:
- disabled_sa = False
-
- if (
- not exists(num_attention_blocks)
- or i < num_attention_blocks[level]
- ):
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads_upsample,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- )
- if not use_spatial_transformer
- else SpatialTransformer3D(
- ch,
- num_heads,
- dim_head,
- depth=transformer_depth,
- context_dim=context_dim,
- disable_self_attn=disabled_sa,
- use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint,
- with_ip=self.with_ip,
- ip_dim=self.ip_dim,
- ip_weight=self.ip_weight
- )
- )
- if level and i == self.num_res_blocks[level]:
- out_ch = ch
- layers.append(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- up=True,
- )
- if resblock_updown
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
- )
- ds //= 2
- self.output_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
-
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
- )
- if self.predict_codebook_ids:
- self.id_predictor = nn.Sequential(
- normalization(ch),
- conv_nd(dims, model_channels, n_embed, 1),
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
- )
-
- def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
- self.input_blocks.apply(convert_module_to_f16)
- self.middle_block.apply(convert_module_to_f16)
- self.output_blocks.apply(convert_module_to_f16)
-
- def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
- self.input_blocks.apply(convert_module_to_f32)
- self.middle_block.apply(convert_module_to_f32)
- self.output_blocks.apply(convert_module_to_f32)
-
- def forward(
- self,
- x,
- timesteps=None,
- context=None,
- y=None,
- camera=None,
- num_frames=1,
- **kwargs,
- ):
- """
- Apply the model to an input batch.
- :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
- :param timesteps: a 1-D batch of timesteps.
- :param context: conditioning plugged in via cross-attention.
- :param y: an [N] Tensor of labels, if class-conditional; default None.
- :param num_frames: an integer indicating the number of frames for tensor reshaping.
- :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
- """
- assert (
- x.shape[0] % num_frames == 0
- ), "[UNet] input batch size must be dividable by num_frames!"
- assert (y is not None) == (
- self.num_classes is not None
- ), "must specify y if and only if the model is class-conditional"
-
- hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
- emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
-
- if self.num_classes is not None:
- assert y.shape[0] == x.shape[0]
- emb = emb + self.label_emb(y)
-
- # Add camera embeddings
- if camera is not None:
- assert camera.shape[0] == emb.shape[0]
- # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
- emb = emb + self.camera_embed(camera)
- ip = kwargs.get("ip", None)
- ip_img = kwargs.get("ip_img", None)
-
- if ip_img is not None:
- x[(num_frames-1)::num_frames, :, :, :] = ip_img
-
- if ip is not None:
- ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
- context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
-
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb, context, num_frames=num_frames)
- hs.append(h)
- h = self.middle_block(h, emb, context, num_frames=num_frames)
- for module in self.output_blocks:
- h = th.cat([h, hs.pop()], dim=1)
- h = module(h, emb, context, num_frames=num_frames)
- h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
- if self.predict_codebook_ids: # False
- return self.id_predictor(h)
- else:
- return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
-
-
-
-
-class MultiViewUNetModelStage2(MultiViewUNetModel):
- """
- The full multi-view UNet model with attention, timestep embedding and camera embedding.
- :param in_channels: channels in the input Tensor.
- :param model_channels: base channel count for the model.
- :param out_channels: channels in the output Tensor.
- :param num_res_blocks: number of residual blocks per downsample.
- :param attention_resolutions: a collection of downsample rates at which
- attention will take place. May be a set, list, or tuple.
- For example, if this contains 4, then at 4x downsampling, attention
- will be used.
- :param dropout: the dropout probability.
- :param channel_mult: channel multiplier for each level of the UNet.
- :param conv_resample: if True, use learned convolutions for upsampling and
- downsampling.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param num_classes: if specified (as an int), then this model will be
- class-conditional with `num_classes` classes.
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
- :param num_heads: the number of attention heads in each attention layer.
- :param num_head_channels: if specified, ignore num_heads and instead use
- a fixed channel width per attention head.
- :param num_heads_upsample: works with num_heads to set a different number
- of heads for upsampling. Deprecated.
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
- :param resblock_updown: use residual blocks for up/downsampling.
- :param use_new_attention_order: use a different attention pattern for potentially
- increased efficiency.
- :param camera_dim: dimensionality of camera input.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- use_checkpoint=False,
- use_fp16=False,
- use_bf16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
- disable_self_attentions=None,
- num_attention_blocks=None,
- disable_middle_self_attn=False,
- use_linear_in_transformer=False,
- adm_in_channels=None,
- camera_dim=None,
- with_ip=False, # whether to use image prompt conditioning
- ip_dim=0, # number of extra tokens: 4 for global, 16 for local
- ip_weight=1.0, # weight for the image prompt context
- ip_mode="local_resample", # adaptor mode: "global" or "local_resample"
- ):
- super().__init__(
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout,
- channel_mult,
- conv_resample,
- dims,
- num_classes,
- use_checkpoint,
- use_fp16,
- use_bf16,
- num_heads,
- num_head_channels,
- num_heads_upsample,
- use_scale_shift_norm,
- resblock_updown,
- use_new_attention_order,
- use_spatial_transformer,
- transformer_depth,
- context_dim,
- n_embed,
- legacy,
- disable_self_attentions,
- num_attention_blocks,
- disable_middle_self_attn,
- use_linear_in_transformer,
- adm_in_channels,
- camera_dim,
- with_ip,
- ip_dim,
- ip_weight,
- ip_mode,
- )
-
- def forward(
- self,
- x,
- timesteps=None,
- context=None,
- y=None,
- camera=None,
- num_frames=1,
- **kwargs,
- ):
- """
- Apply the model to an input batch.
- :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
- :param timesteps: a 1-D batch of timesteps.
- :param context: conditioning plugged in via cross-attention.
- :param y: an [N] Tensor of labels, if class-conditional; default None.
- :param num_frames: an integer indicating the number of frames for tensor reshaping.
- :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
- """
- assert (
- x.shape[0] % num_frames == 0
- ), "[UNet] input batch size must be dividable by num_frames!"
- assert (y is not None) == (
- self.num_classes is not None
- ), "must specify y if and only if the model is class-conditional"
-
- hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
- emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
-
- if self.num_classes is not None:
- assert y.shape[0] == x.shape[0]
- emb = emb + self.label_emb(y)
-
- # Add camera embeddings
- if camera is not None:
- assert camera.shape[0] == emb.shape[0]
- # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
- emb = emb + self.camera_embed(camera)
- ip = kwargs.get("ip", None)
- ip_img = kwargs.get("ip_img", None)
- pixel_images = kwargs.get("pixel_images", None)
-
- if ip_img is not None:
- x[(num_frames-1)::num_frames, :, :, :] = ip_img
-
- x = torch.cat((x, pixel_images), dim=1)
-
- if ip is not None:
- ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
- context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
-
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb, context, num_frames=num_frames)
- hs.append(h)
- h = self.middle_block(h, emb, context, num_frames=num_frames)
- for module in self.output_blocks:
- h = th.cat([h, hs.pop()], dim=1)
- h = module(h, emb, context, num_frames=num_frames)
- h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
- if self.predict_codebook_ids: # False
- return self.id_predictor(h)
- else:
- return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
-
\ No newline at end of file
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py
deleted file mode 100644
index af744261d41deab5d686aead790726efcdfaf961..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py
+++ /dev/null
@@ -1,353 +0,0 @@
-# adopted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-
-
-import os
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import repeat
-import importlib
-
-
-def instantiate_from_config(config):
- if not "target" in config:
- if config == "__is_first_stage__":
- return None
- elif config == "__is_unconditional__":
- return None
- raise KeyError("Expected key `target` to instantiate.")
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
- module, cls = string.rsplit(".", 1)
- if reload:
- module_imp = importlib.import_module(module)
- importlib.reload(module_imp)
- return getattr(importlib.import_module(module, package=None), cls)
-
-
-def make_beta_schedule(
- schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
-):
- if schedule == "linear":
- betas = (
- torch.linspace(
- linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
- )
- ** 2
- )
-
- elif schedule == "cosine":
- timesteps = (
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
- )
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
- alphas = torch.cos(alphas).pow(2)
- alphas = alphas / alphas[0]
- betas = 1 - alphas[1:] / alphas[:-1]
- betas = np.clip(betas, a_min=0, a_max=0.999)
-
- elif schedule == "sqrt_linear":
- betas = torch.linspace(
- linear_start, linear_end, n_timestep, dtype=torch.float64
- )
- elif schedule == "sqrt":
- betas = (
- torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
- ** 0.5
- )
- else:
- raise ValueError(f"schedule '{schedule}' unknown.")
- return betas.numpy()
-
-def enforce_zero_terminal_snr(betas):
- betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas
- # Convert betas to alphas_bar_sqrt
- alphas = 1 - betas
- alphas_bar = alphas.cumprod(0)
- alphas_bar_sqrt = alphas_bar.sqrt()
- # Store old values.
- alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
- alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
- # Shift so last timestep is zero.
- alphas_bar_sqrt -= alphas_bar_sqrt_T
- # Scale so first timestep is back to old value.
- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
- # Convert alphas_bar_sqrt to betas
- alphas_bar = alphas_bar_sqrt ** 2
- alphas = alphas_bar[1:] / alphas_bar[:-1]
- alphas = torch.cat([alphas_bar[0:1], alphas])
- betas = 1 - alphas
- return betas
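A hedged usage sketch of the rescaling above: after the shift-and-scale, the cumulative alpha at the final timestep is exactly zero (zero terminal SNR). The linear schedule is an assumption for illustration.

import torch

betas = make_beta_schedule("linear", n_timestep=1000)   # numpy array from the helper above
betas_zt = enforce_zero_terminal_snr(betas)             # converted to a tensor internally
alphas_bar = (1 - betas_zt).cumprod(dim=0)
print(alphas_bar[-1])                                   # 0.0 -- zero terminal SNR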
-
-
-def make_ddim_timesteps(
- ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
-):
- if ddim_discr_method == "uniform":
- c = num_ddpm_timesteps // num_ddim_timesteps
- ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
- elif ddim_discr_method == "quad":
- ddim_timesteps = (
- (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
- ).astype(int)
- else:
- raise NotImplementedError(
- f'There is no ddim discretization method called "{ddim_discr_method}"'
- )
-
- # assert ddim_timesteps.shape[0] == num_ddim_timesteps
- # add one to get the final alpha values right (the ones from first scale to data during sampling)
- steps_out = ddim_timesteps + 1
- if verbose:
- print(f"Selected timesteps for ddim sampler: {steps_out}")
- return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
- # select alphas for computing the variance schedule
- alphas = alphacums[ddim_timesteps]
- alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
-
- # according to the formula provided in https://arxiv.org/abs/2010.02502
- sigmas = eta * np.sqrt(
- (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
- )
- if verbose:
- print(
- f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
- )
- print(
- f"For the chosen value of eta, which is {eta}, "
- f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
- )
- return sigmas, alphas, alphas_prev
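A hedged sketch of how these two helpers are typically combined to derive a 50-step DDIM schedule from a 1000-step training schedule (the linear schedule here is an assumption for illustration):

import numpy as np

betas = make_beta_schedule("linear", n_timestep=1000)
alphas_cumprod = np.cumprod(1.0 - betas)
ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
sigmas, a_t, a_prev = make_ddim_sampling_parameters(alphas_cumprod, ddim_steps,
                                                    eta=0.0, verbose=False)
# eta=0.0 gives sigmas == 0, i.e. fully deterministic DDIM sampling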
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
- """
- Create a beta schedule that discretizes the given alpha_t_bar function,
- which defines the cumulative product of (1-beta) over time from t = [0,1].
- :param num_diffusion_timesteps: the number of betas to produce.
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
- produces the cumulative product of (1-beta) up to that
- part of the diffusion process.
- :param max_beta: the maximum beta to use; use values lower than 1 to
- prevent singularities.
- """
- betas = []
- for i in range(num_diffusion_timesteps):
- t1 = i / num_diffusion_timesteps
- t2 = (i + 1) / num_diffusion_timesteps
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas)
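For reference, this helper is usually driven by the squared-cosine alpha_bar of Nichol & Dhariwal; a hedged sketch:

import math

cosine_betas = betas_for_alpha_bar(
    1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
)
# 1000 betas; each is clipped to max_beta=0.999, which matters near the end of the schedule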
-
-
-def extract_into_tensor(a, t, x_shape):
- b, *_ = t.shape
- out = a.gather(-1, t)
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
-
-
-def checkpoint(func, inputs, params, flag):
- """
- Evaluate a function without caching intermediate activations, allowing for
- reduced memory at the expense of extra compute in the backward pass.
- :param func: the function to evaluate.
- :param inputs: the argument sequence to pass to `func`.
- :param params: a sequence of parameters `func` depends on but does not
- explicitly take as arguments.
- :param flag: if False, disable gradient checkpointing.
- """
- if flag:
- args = tuple(inputs) + tuple(params)
- return CheckpointFunction.apply(func, len(inputs), *args)
- else:
- return func(*inputs)
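A hedged usage sketch of the helper, mirroring how ResBlock.forward calls it; the toy layer below is an assumption for illustration only:

import torch
import torch.nn as nn

layer = nn.Linear(64, 64)
x = torch.randn(8, 64, requires_grad=True)
y = checkpoint(lambda inp: layer(inp).relu(), (x,), layer.parameters(), True)
y.sum().backward()   # the lambda is re-run here instead of caching its activations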
-
-
-class CheckpointFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, run_function, length, *args):
- ctx.run_function = run_function
- ctx.input_tensors = list(args[:length])
- ctx.input_params = list(args[length:])
-
- with torch.no_grad():
- output_tensors = ctx.run_function(*ctx.input_tensors)
- return output_tensors
-
- @staticmethod
- def backward(ctx, *output_grads):
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
- with torch.enable_grad():
- # Fixes a bug where the first op in run_function modifies the
- # Tensor storage in place, which is not allowed for detach()'d
- # Tensors.
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
- output_tensors = ctx.run_function(*shallow_copies)
- input_grads = torch.autograd.grad(
- output_tensors,
- ctx.input_tensors + ctx.input_params,
- output_grads,
- allow_unused=True,
- )
- del ctx.input_tensors
- del ctx.input_params
- del output_tensors
- return (None, None) + input_grads
-
-
-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
- """
- Create sinusoidal timestep embeddings.
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an [N x dim] Tensor of positional embeddings.
- """
- if not repeat_only:
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period)
- * torch.arange(start=0, end=half, dtype=torch.float32)
- / half
- ).to(device=timesteps.device)
- args = timesteps[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat(
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
- )
- else:
- embedding = repeat(timesteps, "b -> b d", d=dim)
- # import pdb; pdb.set_trace()
- return embedding
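A hedged shape sketch; dim=320 matches the model_channels typically fed into time_embed in the UNet above, but is an assumption here:

import torch

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=320)
print(emb.shape)   # torch.Size([3, 320]); cos/sin values lie in [-1, 1]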
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def scale_module(module, scale):
- """
- Scale the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().mul_(scale)
- return module
-
-
-def mean_flat(tensor):
- """
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def normalization(channels):
- """
- Make a standard normalization layer.
- :param channels: number of input channels.
- :return: an nn.Module for normalization.
- """
- return GroupNorm32(32, channels)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
- def forward(self, x):
- return x * torch.sigmoid(x)
-
-
-class GroupNorm32(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.float()).type(x.dtype)
-
-
-def conv_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D convolution module.
- """
- if dims == 1:
- return nn.Conv1d(*args, **kwargs)
- elif dims == 2:
- return nn.Conv2d(*args, **kwargs)
- elif dims == 3:
- return nn.Conv3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def linear(*args, **kwargs):
- """
- Create a linear module.
- """
- return nn.Linear(*args, **kwargs)
-
-
-def avg_pool_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D average pooling module.
- """
- if dims == 1:
- return nn.AvgPool1d(*args, **kwargs)
- elif dims == 2:
- return nn.AvgPool2d(*args, **kwargs)
- elif dims == 3:
- return nn.AvgPool3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class HybridConditioner(nn.Module):
- def __init__(self, c_concat_config, c_crossattn_config):
- super().__init__()
- self.concat_conditioner = instantiate_from_config(c_concat_config)
- self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
-
- def forward(self, c_concat, c_crossattn):
- c_concat = self.concat_conditioner(c_concat)
- c_crossattn = self.crossattn_conditioner(c_crossattn)
- return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
-
-
-def noise_like(shape, device, repeat=False):
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
- shape[0], *((1,) * (len(shape) - 1))
- )
- noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise()
-
-
-# dummy replace
-def convert_module_to_f16(l):
- """
- Convert primitive modules to float16.
- """
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
- l.weight.data = l.weight.data.half()
- if l.bias is not None:
- l.bias.data = l.bias.data.half()
-
-def convert_module_to_f32(l):
- """
- Convert primitive modules to float32, undoing convert_module_to_f16().
- """
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
- l.weight.data = l.weight.data.float()
- if l.bias is not None:
- l.bias.data = l.bias.data.float()
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 312f98ad511d7f30d3483685fa70528b1910b20a..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 30cc9dbbbadd654864db556064cf17f299a66b28..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc
deleted file mode 100644
index c80371d54b3d6110586bfe2f53628dbae828ecde..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
deleted file mode 100644
index ee436d35f6a01e2ff87c0588b6ec2bdf7211f43b..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py b/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py
deleted file mode 100644
index 92f4428a3defd8fbae18fcd323c9d404036c652e..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-import numpy as np
-
-
-class AbstractDistribution:
- def sample(self):
- raise NotImplementedError()
-
- def mode(self):
- raise NotImplementedError()
-
-
-class DiracDistribution(AbstractDistribution):
- def __init__(self, value):
- self.value = value
-
- def sample(self):
- return self.value
-
- def mode(self):
- return self.value
-
-
-class DiagonalGaussianDistribution(object):
- def __init__(self, parameters, deterministic=False):
- self.parameters = parameters
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
- self.deterministic = deterministic
- self.std = torch.exp(0.5 * self.logvar)
- self.var = torch.exp(self.logvar)
- if self.deterministic:
- self.var = self.std = torch.zeros_like(self.mean).to(
- device=self.parameters.device
- )
-
- def sample(self):
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
- device=self.parameters.device
- )
- return x
-
- def kl(self, other=None):
- if self.deterministic:
- return torch.Tensor([0.0])
- else:
- if other is None:
- return 0.5 * torch.sum(
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
- dim=[1, 2, 3],
- )
- else:
- return 0.5 * torch.sum(
- torch.pow(self.mean - other.mean, 2) / other.var
- + self.var / other.var
- - 1.0
- - self.logvar
- + other.logvar,
- dim=[1, 2, 3],
- )
-
- def nll(self, sample, dims=[1, 2, 3]):
- if self.deterministic:
- return torch.Tensor([0.0])
- logtwopi = np.log(2.0 * np.pi)
- return 0.5 * torch.sum(
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
- dim=dims,
- )
-
- def mode(self):
- return self.mean
-
-
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
- Compute the KL divergence between two gaussians.
- Shapes are automatically broadcasted, so batches can be compared to
- scalars, among other use cases.
- """
- tensor = None
- for obj in (mean1, logvar1, mean2, logvar2):
- if isinstance(obj, torch.Tensor):
- tensor = obj
- break
- assert tensor is not None, "at least one argument must be a Tensor"
-
- # Force variances to be Tensors. Broadcasting helps convert scalars to
- # Tensors, but it does not work for torch.exp().
- logvar1, logvar2 = [
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
- for x in (logvar1, logvar2)
- ]
-
- return 0.5 * (
- -1.0
- + logvar2
- - logvar1
- + torch.exp(logvar1 - logvar2)
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
- )
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/ema.py b/apps/third_party/CRM/imagedream/ldm/modules/ema.py
deleted file mode 100644
index a073e116975f3313fb630b7d4ac115171c1fe31d..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/ema.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import torch
-from torch import nn
-
-
-class LitEma(nn.Module):
- def __init__(self, model, decay=0.9999, use_num_upates=True):
- super().__init__()
- if decay < 0.0 or decay > 1.0:
- raise ValueError("Decay must be between 0 and 1")
-
- self.m_name2s_name = {}
- self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
- self.register_buffer(
- "num_updates",
- torch.tensor(0, dtype=torch.int)
- if use_num_upates
- else torch.tensor(-1, dtype=torch.int),
- )
-
- for name, p in model.named_parameters():
- if p.requires_grad:
- # remove as '.'-character is not allowed in buffers
- s_name = name.replace(".", "")
- self.m_name2s_name.update({name: s_name})
- self.register_buffer(s_name, p.clone().detach().data)
-
- self.collected_params = []
-
- def reset_num_updates(self):
- del self.num_updates
- self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
-
- def forward(self, model):
- decay = self.decay
-
- if self.num_updates >= 0:
- self.num_updates += 1
- decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
-
- one_minus_decay = 1.0 - decay
-
- with torch.no_grad():
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
-
- for key in m_param:
- if m_param[key].requires_grad:
- sname = self.m_name2s_name[key]
- shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
- shadow_params[sname].sub_(
- one_minus_decay * (shadow_params[sname] - m_param[key])
- )
- else:
- assert not key in self.m_name2s_name
-
- def copy_to(self, model):
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
- for key in m_param:
- if m_param[key].requires_grad:
- m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
- else:
- assert not key in self.m_name2s_name
-
- def store(self, parameters):
- """
- Save the current parameters for restoring later.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored.
- """
- self.collected_params = [param.clone() for param in parameters]
-
- def restore(self, parameters):
- """
- Restore the parameters stored with the `store` method.
- Useful to validate the model with EMA parameters without affecting the
- original optimization process. Store the parameters before the
- `copy_to` method. After validation (or model saving), use this to
- restore the former parameters.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters.
- """
- for c_param, param in zip(self.collected_params, parameters):
- param.data.copy_(c_param.data)
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 6ad2b831f070c88d8b1b6d35c697cbb2b8466d66..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 98745a81153dfd563cb689b76944075d93cd4feb..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc
deleted file mode 100644
index 8c722f5bdac9cc12253ad9934ee36901b8468d57..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
deleted file mode 100644
index 70899c6fc30bc670956be3ac51675b27beae2d02..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py b/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py
deleted file mode 100644
index a19d2c193d660535f101fa1cc3f1b857ce1197fc..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py
+++ /dev/null
@@ -1,329 +0,0 @@
-import torch
-import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
-
-from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
-
-import numpy as np
-import open_clip
-from PIL import Image
-from ...util import default, count_params
-
-
-class AbstractEncoder(nn.Module):
- def __init__(self):
- super().__init__()
-
- def encode(self, *args, **kwargs):
- raise NotImplementedError
-
-
-class IdentityEncoder(AbstractEncoder):
- def encode(self, x):
- return x
-
-
-class ClassEmbedder(nn.Module):
- def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
- super().__init__()
- self.key = key
- self.embedding = nn.Embedding(n_classes, embed_dim)
- self.n_classes = n_classes
- self.ucg_rate = ucg_rate
-
- def forward(self, batch, key=None, disable_dropout=False):
- if key is None:
- key = self.key
- # this is for use in crossattn
- c = batch[key][:, None]
- if self.ucg_rate > 0.0 and not disable_dropout:
- mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
- c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
- c = c.long()
- c = self.embedding(c)
- return c
-
- def get_unconditional_conditioning(self, bs, device="cuda"):
- uc_class = (
- self.n_classes - 1
- ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
- uc = torch.ones((bs,), device=device) * uc_class
- uc = {self.key: uc}
- return uc
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-class FrozenT5Embedder(AbstractEncoder):
- """Uses the T5 transformer encoder for text"""
-
- def __init__(
- self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
- ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
- super().__init__()
- self.tokenizer = T5Tokenizer.from_pretrained(version)
- self.transformer = T5EncoderModel.from_pretrained(version)
- self.device = device
- self.max_length = max_length # TODO: typical value?
- if freeze:
- self.freeze()
-
- def freeze(self):
- self.transformer = self.transformer.eval()
- # self.train = disabled_train
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- batch_encoding = self.tokenizer(
- text,
- truncation=True,
- max_length=self.max_length,
- return_length=True,
- return_overflowing_tokens=False,
- padding="max_length",
- return_tensors="pt",
- )
- tokens = batch_encoding["input_ids"].to(self.device)
- outputs = self.transformer(input_ids=tokens)
-
- z = outputs.last_hidden_state
- return z
-
- def encode(self, text):
- return self(text)
-
-
-class FrozenCLIPEmbedder(AbstractEncoder):
- """Uses the CLIP transformer encoder for text (from huggingface)"""
-
- LAYERS = ["last", "pooled", "hidden"]
-
- def __init__(
- self,
- version="openai/clip-vit-large-patch14",
- device="cuda",
- max_length=77,
- freeze=True,
- layer="last",
- layer_idx=None,
- ): # clip-vit-base-patch32
- super().__init__()
- assert layer in self.LAYERS
- self.tokenizer = CLIPTokenizer.from_pretrained(version)
- self.transformer = CLIPTextModel.from_pretrained(version)
- self.device = device
- self.max_length = max_length
- if freeze:
- self.freeze()
- self.layer = layer
- self.layer_idx = layer_idx
- if layer == "hidden":
- assert layer_idx is not None
- assert 0 <= abs(layer_idx) <= 12
-
- def freeze(self):
- self.transformer = self.transformer.eval()
- # self.train = disabled_train
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- batch_encoding = self.tokenizer(
- text,
- truncation=True,
- max_length=self.max_length,
- return_length=True,
- return_overflowing_tokens=False,
- padding="max_length",
- return_tensors="pt",
- )
- tokens = batch_encoding["input_ids"].to(self.device)
- outputs = self.transformer(
- input_ids=tokens, output_hidden_states=self.layer == "hidden"
- )
- if self.layer == "last":
- z = outputs.last_hidden_state
- elif self.layer == "pooled":
- z = outputs.pooler_output[:, None, :]
- else:
- z = outputs.hidden_states[self.layer_idx]
- return z
-
- def encode(self, text):
- return self(text)
-
-
-class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
- """
- Uses the OpenCLIP transformer encoder for text
- """
-
- LAYERS = [
- # "pooled",
- "last",
- "penultimate",
- ]
-
- def __init__(
- self,
- arch="ViT-H-14",
- version="laion2b_s32b_b79k",
- device="cuda",
- max_length=77,
- freeze=True,
- layer="last",
- ip_mode=None
- ):
- """_summary_
-
- Args:
- ip_mode (str, optional): the image processing mode. Defaults to None.
-
- """
- super().__init__()
- assert layer in self.LAYERS
- model, _, preprocess = open_clip.create_model_and_transforms(
- arch, device=torch.device("cpu"), pretrained=version
- )
- if ip_mode is None:
- del model.visual
-
- self.model = model
- self.preprocess = preprocess
- self.device = device
- self.max_length = max_length
- self.ip_mode = ip_mode
- if freeze:
- self.freeze()
- self.layer = layer
- if self.layer == "last":
- self.layer_idx = 0
- elif self.layer == "penultimate":
- self.layer_idx = 1
- else:
- raise NotImplementedError()
-
- def freeze(self):
- self.model = self.model.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- tokens = open_clip.tokenize(text)
- z = self.encode_with_transformer(tokens.to(self.device))
- return z
-
- def forward_image(self, pil_image):
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- if isinstance(pil_image, torch.Tensor):
- pil_image = pil_image.cpu().numpy()
- if isinstance(pil_image, np.ndarray):
- if pil_image.ndim == 3:
- pil_image = pil_image[None, :, :, :]
- pil_image = [Image.fromarray(x) for x in pil_image]
-
- images = []
- for image in pil_image:
- images.append(self.preprocess(image).to(self.device))
-
- image = torch.stack(images, 0) # to [b, 3, h, w]
- if self.ip_mode == "global":
- image_features = self.model.encode_image(image)
- image_features /= image_features.norm(dim=-1, keepdim=True)
- elif "local" in self.ip_mode:
- image_features = self.encode_image_with_transformer(image)
-
- return image_features # b, l
-
- def encode_image_with_transformer(self, x):
- visual = self.model.visual
- x = visual.conv1(x) # shape = [*, width, grid, grid]
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
-
- # class embeddings and positional embeddings
- x = torch.cat(
- [visual.class_embedding.to(x.dtype) + \
- torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
- x], dim=1) # shape = [*, grid ** 2 + 1, width]
- x = x + visual.positional_embedding.to(x.dtype)
-
- # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
- # x = visual.patch_dropout(x)
- x = visual.ln_pre(x)
-
- x = x.permute(1, 0, 2) # NLD -> LND
- hidden = self.image_transformer_forward(x)
- x = hidden[-2].permute(1, 0, 2) # LND -> NLD
- return x
-
- def image_transformer_forward(self, x):
- encoder_states = ()
- trans = self.model.visual.transformer
- for r in trans.resblocks:
- if trans.grad_checkpointing and not torch.jit.is_scripting():
- # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
- x = checkpoint(r, x, None, None, None)
- else:
- x = r(x, attn_mask=None)
- encoder_states = encoder_states + (x, )
- return encoder_states
-
- def encode_with_transformer(self, text):
- x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
- x = x + self.model.positional_embedding
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
- x = x.permute(1, 0, 2) # LND -> NLD
- x = self.model.ln_final(x)
- return x
-
- def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
- for i, r in enumerate(self.model.transformer.resblocks):
- if i == len(self.model.transformer.resblocks) - self.layer_idx:
- break
- if (
- self.model.transformer.grad_checkpointing
- and not torch.jit.is_scripting()
- ):
- x = checkpoint(r, x, attn_mask)
- else:
- x = r(x, attn_mask=attn_mask)
- return x
-
- def encode(self, text):
- return self(text)
-
-
-class FrozenCLIPT5Encoder(AbstractEncoder):
- def __init__(
- self,
- clip_version="openai/clip-vit-large-patch14",
- t5_version="google/t5-v1_1-xl",
- device="cuda",
- clip_max_length=77,
- t5_max_length=77,
- ):
- super().__init__()
- self.clip_encoder = FrozenCLIPEmbedder(
- clip_version, device, max_length=clip_max_length
- )
- self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
- print(
- f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
- f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
- )
-
- def encode(self, text):
- return self(text)
-
- def forward(self, text):
- clip_z = self.clip_encoder.encode(text)
- t5_z = self.t5_encoder.encode(text)
- return [clip_z, t5_z]
diff --git a/apps/third_party/CRM/imagedream/ldm/util.py b/apps/third_party/CRM/imagedream/ldm/util.py
deleted file mode 100644
index 8ed44393d153aad633175d3afda2a1d923c7d815..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/ldm/util.py
+++ /dev/null
@@ -1,231 +0,0 @@
-import importlib
-
-import random
-import torch
-import numpy as np
-from collections import abc
-
-import multiprocessing as mp
-from threading import Thread
-from queue import Queue
-
-from inspect import isfunction
-from PIL import Image, ImageDraw, ImageFont
-
-
-def log_txt_as_img(wh, xc, size=10):
- # wh a tuple of (width, height)
- # xc a list of captions to plot
- b = len(xc)
- txts = list()
- for bi in range(b):
- txt = Image.new("RGB", wh, color="white")
- draw = ImageDraw.Draw(txt)
- font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
- nc = int(40 * (wh[0] / 256))
- lines = "\n".join(
- xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
- )
-
- try:
- draw.text((0, 0), lines, fill="black", font=font)
- except UnicodeEncodeError:
- print("Cant encode string for logging. Skipping.")
-
- txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
- txts.append(txt)
- txts = np.stack(txts)
- txts = torch.tensor(txts)
- return txts
-
-
-def ismap(x):
- if not isinstance(x, torch.Tensor):
- return False
- return (len(x.shape) == 4) and (x.shape[1] > 3)
-
-
-def isimage(x):
- if not isinstance(x, torch.Tensor):
- return False
- return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
-
-
-def exists(x):
- return x is not None
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def mean_flat(tensor):
- """
- https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def count_params(model, verbose=False):
- total_params = sum(p.numel() for p in model.parameters())
- if verbose:
- print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
- return total_params
-
-
-def instantiate_from_config(config):
- if not "target" in config:
- if config == "__is_first_stage__":
- return None
- elif config == "__is_unconditional__":
- return None
- raise KeyError("Expected key `target` to instantiate.")
- # import pdb; pdb.set_trace()
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
- module, cls = string.rsplit(".", 1)
- # import pdb; pdb.set_trace()
- if reload:
- module_imp = importlib.import_module(module)
- importlib.reload(module_imp)
-
- if 'imagedream' in module:
- module = 'apps.third_party.CRM.'+module
- if 'lib' in module:
- module = 'apps.third_party.CRM.'+module
- return getattr(importlib.import_module(module, package=None), cls)
-
-
-def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
- # create dummy dataset instance
-
- # run prefetching
- if idx_to_fn:
- res = func(data, worker_id=idx)
- else:
- res = func(data)
- Q.put([idx, res])
- Q.put("Done")
-
-
-def parallel_data_prefetch(
- func: callable,
- data,
- n_proc,
- target_data_type="ndarray",
- cpu_intensive=True,
- use_worker_id=False,
-):
- # if target_data_type not in ["ndarray", "list"]:
- # raise ValueError(
- # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
- # )
- if isinstance(data, np.ndarray) and target_data_type == "list":
- raise ValueError("list expected but function got ndarray.")
- elif isinstance(data, abc.Iterable):
- if isinstance(data, dict):
- print(
- f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
- )
- data = list(data.values())
- if target_data_type == "ndarray":
- data = np.asarray(data)
- else:
- data = list(data)
- else:
- raise TypeError(
- f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
- )
-
- if cpu_intensive:
- Q = mp.Queue(1000)
- proc = mp.Process
- else:
- Q = Queue(1000)
- proc = Thread
- # spawn processes
- if target_data_type == "ndarray":
- arguments = [
- [func, Q, part, i, use_worker_id]
- for i, part in enumerate(np.array_split(data, n_proc))
- ]
- else:
- step = (
- int(len(data) / n_proc + 1)
- if len(data) % n_proc != 0
- else int(len(data) / n_proc)
- )
- arguments = [
- [func, Q, part, i, use_worker_id]
- for i, part in enumerate(
- [data[i : i + step] for i in range(0, len(data), step)]
- )
- ]
- processes = []
- for i in range(n_proc):
- p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
- processes += [p]
-
- # start processes
- print(f"Start prefetching...")
- import time
-
- start = time.time()
- gather_res = [[] for _ in range(n_proc)]
- try:
- for p in processes:
- p.start()
-
- k = 0
- while k < n_proc:
- # get result
- res = Q.get()
- if res == "Done":
- k += 1
- else:
- gather_res[res[0]] = res[1]
-
- except Exception as e:
- print("Exception: ", e)
- for p in processes:
- p.terminate()
-
- raise e
- finally:
- for p in processes:
- p.join()
- print(f"Prefetching complete. [{time.time() - start} sec.]")
-
- if target_data_type == "ndarray":
- if not isinstance(gather_res[0], np.ndarray):
- return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
-
- # order outputs
- return np.concatenate(gather_res, axis=0)
- elif target_data_type == "list":
- out = []
- for r in gather_res:
- out.extend(r)
- return out
- else:
- return gather_res
-
-def set_seed(seed=None):
- random.seed(seed)
- np.random.seed(seed)
- if seed is not None:
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
-
-def add_random_background(image, bg_color=None):
- bg_color = np.random.rand() * 255 if bg_color is None else bg_color
- image = np.array(image)
- rgb, alpha = image[..., :3], image[..., 3:]
- alpha = alpha.astype(np.float32) / 255.0
- image_new = rgb * alpha + bg_color * (1 - alpha)
- return Image.fromarray(image_new.astype(np.uint8))
\ No newline at end of file
diff --git a/apps/third_party/CRM/imagedream/model_zoo.py b/apps/third_party/CRM/imagedream/model_zoo.py
deleted file mode 100644
index 45d6b678bdc554f5b2ad19903f8ed9976ece024e..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/imagedream/model_zoo.py
+++ /dev/null
@@ -1,64 +0,0 @@
-""" Utiliy functions to load pre-trained models more easily """
-import os
-import pkg_resources
-from omegaconf import OmegaConf
-
-import torch
-from huggingface_hub import hf_hub_download
-
-from imagedream.ldm.util import instantiate_from_config
-
-
-PRETRAINED_MODELS = {
- "sd-v2.1-base-4view-ipmv": {
- "config": "sd_v2_base_ipmv.yaml",
- "repo_id": "Peng-Wang/ImageDream",
- "filename": "sd-v2.1-base-4view-ipmv.pt",
- },
- "sd-v2.1-base-4view-ipmv-local": {
- "config": "sd_v2_base_ipmv_local.yaml",
- "repo_id": "Peng-Wang/ImageDream",
- "filename": "sd-v2.1-base-4view-ipmv-local.pt",
- },
-}
-
-
-def get_config_file(config_path):
- cfg_file = pkg_resources.resource_filename(
- "imagedream", os.path.join("configs", config_path)
- )
- if not os.path.exists(cfg_file):
- raise RuntimeError(f"Config {config_path} not available!")
- return cfg_file
-
-
-def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None):
- if (config_path is not None) and (ckpt_path is not None):
- config = OmegaConf.load(config_path)
- model = instantiate_from_config(config.model)
- model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
- return model
-
- if not model_name in PRETRAINED_MODELS:
- raise RuntimeError(
- f"Model name {model_name} is not a pre-trained model. Available models are:\n- "
- + "\n- ".join(PRETRAINED_MODELS.keys())
- )
- model_info = PRETRAINED_MODELS[model_name]
-
- # Instantiate the model
- print(f"Loading model from config: {model_info['config']}")
- config_file = get_config_file(model_info["config"])
- config = OmegaConf.load(config_file)
- model = instantiate_from_config(config.model)
-
- # Load pre-trained checkpoint from huggingface
- if not ckpt_path:
- ckpt_path = hf_hub_download(
- repo_id=model_info["repo_id"],
- filename=model_info["filename"],
- cache_dir=cache_dir,
- )
- print(f"Loading model from cache file: {ckpt_path}")
- model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
- return model
diff --git a/apps/third_party/CRM/inference.py b/apps/third_party/CRM/inference.py
deleted file mode 100644
index a6fc2a9e49d606dc44d0cb8ae0cfbf223b44dcd4..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/inference.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import numpy as np
-import torch
-import time
-import nvdiffrast.torch as dr
-from util.utils import get_tri
-import tempfile
-from mesh import Mesh
-import zipfile
-def generate3d(model, rgb, ccm, device):
-
- color_tri = torch.from_numpy(rgb)/255
- xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
- color = color_tri.permute(2,0,1)
- xyz = xyz_tri.permute(2,0,1)
-
-
- def get_imgs(color):
- # color : [C, H, W*6]
- color_list = []
- color_list.append(color[:,:,256*5:256*(1+5)])
- for i in range(0,5):
- color_list.append(color[:,:,256*i:256*(1+i)])
- return torch.stack(color_list, dim=0)# [6, C, H, W]
-
- triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C]
-
- color = get_imgs(color)
- xyz = get_imgs(xyz)
-
- color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0)
- xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0)
-
- triplane = torch.cat([color,xyz],dim=1).to(device)
- # 3D visualize
- model.eval()
- glctx = dr.RasterizeCudaContext()
-
- if model.denoising == True:
- tnew = 20
- tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
- noise_new = torch.randn_like(triplane) *0.5+0.5
- triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
- start_time = time.time()
- with torch.no_grad():
- triplane_feature2 = model.unet2(triplane,tnew)
- end_time = time.time()
- elapsed_time = end_time - start_time
- print(f"unet takes {elapsed_time}s")
- else:
- triplane_feature2 = model.unet2(triplane)
-
-
- with torch.no_grad():
- data_config = {
- 'resolution': [1024, 1024],
- "triview_color": triplane_color.to(device),
- }
-
- verts, faces = model.decode(data_config, triplane_feature2)
-
- data_config['verts'] = verts[0]
- data_config['faces'] = faces
-
-
- from kiui.mesh_utils import clean_mesh
- verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=False, remesh_size=0.005)
- data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
- data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
-
- start_time = time.time()
- with torch.no_grad():
- mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
- model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2)
-
- mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, front_dir="+z")
- mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
- mesh.write(mesh_path_glb+".glb")
-
- # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb')
- # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
- # mesh_obj2.export(mesh_path_obj2+".obj")
-
- with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip:
- myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj')
- myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png')
- myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl')
-
- end_time = time.time()
- elapsed_time = end_time - start_time
- print(f"uv takes {elapsed_time}s")
- return mesh_path_glb+".glb", mesh_path_obj+'.zip'
diff --git a/apps/third_party/CRM/libs/__init__.py b/apps/third_party/CRM/libs/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 377a573aa67f7cc53af4c8b05123f278611521fc..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 1b54287998f0608201936c65b8676e9a2f8d2c9f..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc
deleted file mode 100644
index b19825c00604c13eafa9ac4db0ce70a9bf04e1df..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc
deleted file mode 100644
index fc8c696f4c72d09559bcbc1519c671680fd92a35..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc
deleted file mode 100644
index 1b872ef355b5c9cac32066264ac6d3fe8e6f49d0..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc
deleted file mode 100644
index 888166c9fae24f2fd70b9939ac5803ec24ab5895..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/libs/base_utils.py b/apps/third_party/CRM/libs/base_utils.py
deleted file mode 100644
index c90a548286e80ded1f317126d8f560f67a85e0d8..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/libs/base_utils.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import numpy as np
-import cv2
-import torch
-import numpy as np
-from PIL import Image
-
-
-def instantiate_from_config(config):
- if not "target" in config:
- raise KeyError("Expected key `target` to instantiate.")
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
- import importlib
- module, cls = string.rsplit(".", 1)
- if reload:
- module_imp = importlib.import_module(module)
- importlib.reload(module_imp)
- return getattr(importlib.import_module(module, package=None), cls)
-
-
-def tensor_detail(t):
- assert type(t) == torch.Tensor
- print(f"shape: {t.shape} mean: {t.mean():.2f}, std: {t.std():.2f}, min: {t.min():.2f}, max: {t.max():.2f}")
-
-
-
-def drawRoundRec(draw, color, x, y, w, h, r):
- drawObject = draw
-
- '''Rounds'''
- drawObject.ellipse((x, y, x + r, y + r), fill=color)
- drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
- drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
- drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)
-
- '''rec.s'''
- drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
- drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)
-
-
-def do_resize_content(original_image: Image, scale_rate):
- # resize image content while retaining the original image size
- if scale_rate != 1:
- # Calculate the new size after rescaling
- new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
- # Resize the image while maintaining the aspect ratio
- resized_image = original_image.resize(new_size)
- # Create a new image with the original size and black background
- padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
- paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
- padded_image.paste(resized_image, paste_position)
- return padded_image
- else:
- return original_image
-
-def add_stroke(img, color=(255, 255, 255), stroke_radius=3):
- # color in R, G, B format
- if isinstance(img, Image.Image):
- assert img.mode == "RGBA"
- img = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2BGRA)
- else:
- assert img.shape[2] == 4
- gray = img[:,:, 3]
- ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY)
- contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
- res = cv2.drawContours(img, contours,-1, tuple(color)[::-1] + (255,), stroke_radius)
- return Image.fromarray(cv2.cvtColor(res,cv2.COLOR_BGRA2RGBA))
-
-def make_blob(image_size=(512, 512), sigma=0.2):
- """
- make 2D blob image with:
- I(x, y)=1-\exp \left(-\frac{(x-H / 2)^2+(y-W / 2)^2}{2 \sigma^2 H W}\right)
- """
- import numpy as np
- H, W = image_size
- x = np.arange(0, W, 1, float)
- y = np.arange(0, H, 1, float)
- x, y = np.meshgrid(x, y)
- x0 = W // 2
- y0 = H // 2
- img = 1 - np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2 * H * W))
- return (img * 255).astype(np.uint8)
\ No newline at end of file
diff --git a/apps/third_party/CRM/libs/sample.py b/apps/third_party/CRM/libs/sample.py
deleted file mode 100644
index 6a2e5cc453796c10165401bea47814113b6381ba..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/libs/sample.py
+++ /dev/null
@@ -1,384 +0,0 @@
-import numpy as np
-import torch
-from imagedream.camera_utils import get_camera_for_index
-from imagedream.ldm.util import set_seed, add_random_background
-# import os
-# import sys
-# proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-# sys.path.append(proj_dir)
-from apps.third_party.CRM.libs.base_utils import do_resize_content
-from imagedream.ldm.models.diffusion.ddim import DDIMSampler
-from torchvision import transforms as T
-
-
-class ImageDreamDiffusion:
- def __init__(
- self,
- model,
- device,
- dtype,
- mode,
- num_frames,
- camera_views,
- ref_position,
- random_background=False,
- offset_noise=False,
- resize_rate=1,
- image_size=256,
- seed=1234,
- ) -> None:
- assert mode in ["pixel", "local"]
- size = image_size
- self.seed = seed
- batch_size = max(4, num_frames)
-
- neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
- uc = model.get_learned_conditioning([neg_texts]).to(device)
- sampler = DDIMSampler(model)
-
- # pre-compute camera matrices
- camera = [get_camera_for_index(i).squeeze() for i in camera_views]
- camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero
- camera = torch.stack(camera)
- camera = camera.repeat(batch_size // num_frames, 1).to(device)
-
- self.image_transform = T.Compose(
- [
- T.Resize((size, size)),
- T.ToTensor(),
- T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
- ]
- )
- self.dtype = dtype
- self.ref_position = ref_position
- self.mode = mode
- self.random_background = random_background
- self.resize_rate = resize_rate
- self.num_frames = num_frames
- self.size = size
- self.device = device
- self.batch_size = batch_size
- self.model = model
- self.sampler = sampler
- self.uc = uc
- self.camera = camera
- self.offset_noise = offset_noise
-
- @staticmethod
- def i2i(
- model,
- image_size,
- prompt,
- uc,
- sampler,
- ip=None,
- step=20,
- scale=5.0,
- batch_size=8,
- ddim_eta=0.0,
- dtype=torch.float32,
- device="cuda",
- camera=None,
- num_frames=4,
- pixel_control=False,
- transform=None,
- offset_noise=False,
- ):
- """ The function supports additional image prompt.
- Args:
- model (_type_): the image dream model
- image_size (_type_): size of diffusion output (standard 256)
- prompt (_type_): text prompt for the image (a str, or list of str)
- uc (_type_): unconditional vector (tensor in shape [1, 77, 1024])
- sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler
- ip (Image, optional): the image prompt. Defaults to None.
- step (int, optional): number of DDIM sampling steps. Defaults to 20.
- scale (float, optional): classifier-free guidance scale. Defaults to 5.0.
- batch_size (int, optional): number of samples generated per batch. Defaults to 8.
- ddim_eta (float, optional): DDIM eta parameter. Defaults to 0.0.
- dtype (_type_, optional): autocast dtype used during sampling. Defaults to torch.float32.
- device (str, optional): torch device string. Defaults to "cuda".
- camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
- num_frames (int, optional): number of frames (views) to generate.
- pixel_control: whether to use pixel conditioning. Defaults to False, True when using pixel mode
- transform: Compose(
- Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn)
- ToTensor()
- Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
- )
- """
- ip_raw = ip
- if type(prompt) != list:
- prompt = [prompt]
- with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
- c = model.get_learned_conditioning(prompt).to(
- device
- ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
- c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size
- uc_ = {"context": uc.repeat(batch_size, 1, 1)}
-
- if camera is not None:
- c_["camera"] = uc_["camera"] = (
- camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
- )
- c_["num_frames"] = uc_["num_frames"] = num_frames
-
- if ip is not None:
- ip_embed = model.get_learned_image_conditioning(ip).to(
- device
- ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
- ip_ = ip_embed.repeat(batch_size, 1, 1)
- c_["ip"] = ip_
- uc_["ip"] = torch.zeros_like(ip_)
-
- if pixel_control:
- assert camera is not None
- ip = transform(ip).to(
- device
- ) # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00
- ip_img = model.get_first_stage_encoding(
- model.encode_first_stage(ip[None, :, :, :])
- ) # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55
- c_["ip_img"] = ip_img
- uc_["ip_img"] = torch.zeros_like(ip_img)
-
- shape = [4, image_size // 8, image_size // 8] # [4, 32, 32]
- if offset_noise:
- ref = transform(ip_raw).to(device)
- ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
- ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
- time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
- x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)
-
- samples_ddim, _ = (
- sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
- S=step,
- conditioning=c_,
- batch_size=batch_size,
- shape=shape,
- verbose=False,
- unconditional_guidance_scale=scale,
- unconditional_conditioning=uc_,
- eta=ddim_eta,
- x_T=x_T if offset_noise else None,
- )
- )
-
- x_sample = model.decode_first_stage(samples_ddim)
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()
-
- return list(x_sample.astype(np.uint8))
-
- def diffuse(self, t, ip, n_test=2):
- set_seed(self.seed)
- ip = do_resize_content(ip, self.resize_rate)
- if self.random_background:
- ip = add_random_background(ip)
-
- images = []
- for _ in range(n_test):
- img = self.i2i(
- self.model,
- self.size,
- t,
- self.uc,
- self.sampler,
- ip=ip,
- step=50,
- scale=5,
- batch_size=self.batch_size,
- ddim_eta=0.0,
- dtype=self.dtype,
- device=self.device,
- camera=self.camera,
- num_frames=self.num_frames,
- pixel_control=(self.mode == "pixel"),
- transform=self.image_transform,
- offset_noise=self.offset_noise,
- )
- img = np.concatenate(img, 1)
- img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1)
- images.append(img)
- set_seed() # unset random and numpy seed
- return images
-
-
-class ImageDreamDiffusionStage2:
- def __init__(
- self,
- model,
- device,
- dtype,
- num_frames,
- camera_views,
- ref_position,
- random_background=False,
- offset_noise=False,
- resize_rate=1,
- mode="pixel",
- image_size=256,
- seed=1234,
- ) -> None:
- assert mode in ["pixel", "local"]
-
- size = image_size
- self.seed = seed
- batch_size = max(4, num_frames)
-
- neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
- uc = model.get_learned_conditioning([neg_texts]).to(device)
- sampler = DDIMSampler(model)
-
- # pre-compute camera matrices
- camera = [get_camera_for_index(i).squeeze() for i in camera_views]
- if ref_position is not None:
- camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero
- camera = torch.stack(camera)
- camera = camera.repeat(batch_size // num_frames, 1).to(device)
-
- self.image_transform = T.Compose(
- [
- T.Resize((size, size)),
- T.ToTensor(),
- T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
- ]
- )
-
- self.dtype = dtype
- self.mode = mode
- self.ref_position = ref_position
- self.random_background = random_background
- self.resize_rate = resize_rate
- self.num_frames = num_frames
- self.size = size
- self.device = device
- self.batch_size = batch_size
- self.model = model
- self.sampler = sampler
- self.uc = uc
- self.camera = camera
- self.offset_noise = offset_noise
-
- @staticmethod
- def i2iStage2(
- model,
- image_size,
- prompt,
- uc,
- sampler,
- pixel_images,
- ip=None,
- step=20,
- scale=5.0,
- batch_size=8,
- ddim_eta=0.0,
- dtype=torch.float32,
- device="cuda",
- camera=None,
- num_frames=4,
- pixel_control=False,
- transform=None,
- offset_noise=False,
- ):
- ip_raw = ip
- if type(prompt) != list:
- prompt = [prompt]
- with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
- c = model.get_learned_conditioning(prompt).to(
- device
- ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
- c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size
- uc_ = {"context": uc.repeat(batch_size, 1, 1)}
-
- if camera is not None:
- c_["camera"] = uc_["camera"] = (
- camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
- )
- c_["num_frames"] = uc_["num_frames"] = num_frames
-
- if ip is not None:
- ip_embed = model.get_learned_image_conditioning(ip).to(
- device
- ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
- ip_ = ip_embed.repeat(batch_size, 1, 1)
- c_["ip"] = ip_
- uc_["ip"] = torch.zeros_like(ip_)
-
- if pixel_control:
- assert camera is not None
-
- transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images])
- latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images))
-
- c_["pixel_images"] = latent_pixel_images
- uc_["pixel_images"] = torch.zeros_like(latent_pixel_images)
-
- shape = [4, image_size // 8, image_size // 8] # [4, 32, 32]
- if offset_noise:
- ref = transform(ip_raw).to(device)
- ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
- ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
- time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
- x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)
-
- samples_ddim, _ = (
- sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
- S=step,
- conditioning=c_,
- batch_size=batch_size,
- shape=shape,
- verbose=False,
- unconditional_guidance_scale=scale,
- unconditional_conditioning=uc_,
- eta=ddim_eta,
- x_T=x_T if offset_noise else None,
- )
- )
- x_sample = model.decode_first_stage(samples_ddim)
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()
-
- return list(x_sample.astype(np.uint8))
-
- @torch.no_grad()
- def diffuse(self, t, ip, pixel_images, n_test=2):
- set_seed(self.seed)
- ip = do_resize_content(ip, self.resize_rate)
- pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images]
-
- if self.random_background:
- bg_color = np.random.rand() * 255
- ip = add_random_background(ip, bg_color)
- pixel_images = [add_random_background(i, bg_color) for i in pixel_images]
-
- images = []
- for _ in range(n_test):
- img = self.i2iStage2(
- self.model,
- self.size,
- t,
- self.uc,
- self.sampler,
- pixel_images=pixel_images,
- ip=ip,
- step=50,
- scale=5,
- batch_size=self.batch_size,
- ddim_eta=0.0,
- dtype=self.dtype,
- device=self.device,
- camera=self.camera,
- num_frames=self.num_frames,
- pixel_control=(self.mode == "pixel"),
- transform=self.image_transform,
- offset_noise=self.offset_noise,
- )
- img = np.concatenate(img, 1)
- img = np.concatenate(
- (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]),
- axis=1,
- )
- images.append(img)
- set_seed() # unset random and numpy seed
- return images
diff --git a/apps/third_party/CRM/mesh.py b/apps/third_party/CRM/mesh.py
deleted file mode 100644
index b98dea041fb41d207b6e95ed927216344854d25c..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/mesh.py
+++ /dev/null
@@ -1,845 +0,0 @@
-import os
-import cv2
-import torch
-import trimesh
-import numpy as np
-
-from kiui.op import safe_normalize, dot
-from kiui.typing import *
-
-class Mesh:
- """
- A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
-
- Note:
- This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
- """
- def __init__(
- self,
- v: Optional[Tensor] = None,
- f: Optional[Tensor] = None,
- vn: Optional[Tensor] = None,
- fn: Optional[Tensor] = None,
- vt: Optional[Tensor] = None,
- ft: Optional[Tensor] = None,
- vc: Optional[Tensor] = None, # vertex color
- albedo: Optional[Tensor] = None,
- metallicRoughness: Optional[Tensor] = None,
- device: Optional[torch.device] = None,
- ):
- """Init a mesh directly using all attributes.
-
- Args:
- v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
- f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
- vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
- fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
- vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
- ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
- vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
- albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
- metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
- device (Optional[torch.device]): torch device. Defaults to None.
- """
- self.device = device
- self.v = v
- self.vn = vn
- self.vt = vt
- self.f = f
- self.fn = fn
- self.ft = ft
- # will first see if there is vertex color to use
- self.vc = vc
- # only support a single albedo image
- self.albedo = albedo
- # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
- # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
- self.metallicRoughness = metallicRoughness
-
- self.ori_center = 0
- self.ori_scale = 1
-
- @classmethod
- def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
- """load mesh from path.
-
- Args:
- path (str): path to mesh file, supports ply, obj, glb.
- clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
- resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
- renormal (bool, optional): re-calc the vertex normals. Defaults to True.
- retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
- bound (float, optional): bound to resize. Defaults to 0.9.
- front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
- device (torch.device, optional): torch device. Defaults to None.
-
- Note:
- a ``device`` keyword argument can be provided to specify the torch device.
- If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
-
- Returns:
- Mesh: the loaded Mesh object.
- """
- # obj supports face uv
- if path.endswith(".obj"):
- mesh = cls.load_obj(path, **kwargs)
- # trimesh only supports vertex uv, but can load more formats
- else:
- mesh = cls.load_trimesh(path, **kwargs)
-
- # clean
- if clean:
- from kiui.mesh_utils import clean_mesh
- vertices = mesh.v.detach().cpu().numpy()
- triangles = mesh.f.detach().cpu().numpy()
- vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
- mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
- mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
-
- print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
- # auto-normalize
- if resize:
- mesh.auto_size(bound=bound)
- # auto-fix normal
- if renormal or mesh.vn is None:
- mesh.auto_normal()
- print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
- # auto-fix texcoords
- if retex or (mesh.albedo is not None and mesh.vt is None):
- mesh.auto_uv(cache_path=path)
- print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
-
- # rotate front dir to +z
- if front_dir != "+z":
- # axis switch
- if "-z" in front_dir:
- T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
- elif "+x" in front_dir:
- T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
- elif "-x" in front_dir:
- T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
- elif "+y" in front_dir:
- T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
- elif "-y" in front_dir:
- T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
- else:
- T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- # rotation (how many 90 degrees)
- if '1' in front_dir:
- T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- elif '2' in front_dir:
- T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- elif '3' in front_dir:
- T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- mesh.v @= T
- mesh.vn @= T
-
- return mesh
-
- # load from obj file
- @classmethod
- def load_obj(cls, path, albedo_path=None, device=None):
- """load an ``obj`` mesh.
-
- Args:
- path (str): path to mesh.
- albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
- device (torch.device, optional): torch device. Defaults to None.
-
- Note:
- We will try to read the `mtl` path from the `obj` file; otherwise we assume the file name is the same as the `obj` but with the `mtl` extension.
- The `usemtl` statement is ignored, and we only use the last material path in the `mtl` file.
-
- Returns:
- Mesh: the loaded Mesh object.
- """
- assert os.path.splitext(path)[-1] == ".obj"
-
- mesh = cls()
-
- # device
- if device is None:
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- mesh.device = device
-
- # load obj
- with open(path, "r") as f:
- lines = f.readlines()
-
- def parse_f_v(fv):
- # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
- # supported forms:
- # f v1 v2 v3
- # f v1/vt1 v2/vt2 v3/vt3
- # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
- # f v1//vn1 v2//vn2 v3//vn3
- xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
- xs.extend([-1] * (3 - len(xs)))
- return xs[0], xs[1], xs[2]
-
- vertices, texcoords, normals = [], [], []
- faces, tfaces, nfaces = [], [], []
- mtl_path = None
-
- for line in lines:
- split_line = line.split()
- # empty line
- if len(split_line) == 0:
- continue
- prefix = split_line[0].lower()
- # mtllib
- if prefix == "mtllib":
- mtl_path = split_line[1]
- # usemtl
- elif prefix == "usemtl":
- pass # ignored
- # v/vn/vt
- elif prefix == "v":
- vertices.append([float(v) for v in split_line[1:]])
- elif prefix == "vn":
- normals.append([float(v) for v in split_line[1:]])
- elif prefix == "vt":
- val = [float(v) for v in split_line[1:]]
- texcoords.append([val[0], 1.0 - val[1]])
- elif prefix == "f":
- vs = split_line[1:]
- nv = len(vs)
- v0, t0, n0 = parse_f_v(vs[0])
- for i in range(nv - 2): # triangulate (assume vertices are ordered)
- v1, t1, n1 = parse_f_v(vs[i + 1])
- v2, t2, n2 = parse_f_v(vs[i + 2])
- faces.append([v0, v1, v2])
- tfaces.append([t0, t1, t2])
- nfaces.append([n0, n1, n2])
-
- mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
- mesh.vt = (
- torch.tensor(texcoords, dtype=torch.float32, device=device)
- if len(texcoords) > 0
- else None
- )
- mesh.vn = (
- torch.tensor(normals, dtype=torch.float32, device=device)
- if len(normals) > 0
- else None
- )
-
- mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
- mesh.ft = (
- torch.tensor(tfaces, dtype=torch.int32, device=device)
- if len(texcoords) > 0
- else None
- )
- mesh.fn = (
- torch.tensor(nfaces, dtype=torch.int32, device=device)
- if len(normals) > 0
- else None
- )
-
- # see if there is vertex color
- use_vertex_color = False
- if mesh.v.shape[1] == 6:
- use_vertex_color = True
- mesh.vc = mesh.v[:, 3:]
- mesh.v = mesh.v[:, :3]
- print(f"[load_obj] use vertex color: {mesh.vc.shape}")
-
- # try to load texture image
- if not use_vertex_color:
- # try to retrieve mtl file
- mtl_path_candidates = []
- if mtl_path is not None:
- mtl_path_candidates.append(mtl_path)
- mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
- mtl_path_candidates.append(path.replace(".obj", ".mtl"))
-
- mtl_path = None
- for candidate in mtl_path_candidates:
- if os.path.exists(candidate):
- mtl_path = candidate
- break
-
- # if albedo_path is not provided, try retrieve it from mtl
- metallic_path = None
- roughness_path = None
- if mtl_path is not None and albedo_path is None:
- with open(mtl_path, "r") as f:
- lines = f.readlines()
-
- for line in lines:
- split_line = line.split()
- # empty line
- if len(split_line) == 0:
- continue
- prefix = split_line[0]
-
- if "map_Kd" in prefix:
- # assume relative path!
- albedo_path = os.path.join(os.path.dirname(path), split_line[1])
- print(f"[load_obj] use texture from: {albedo_path}")
- elif "map_Pm" in prefix:
- metallic_path = os.path.join(os.path.dirname(path), split_line[1])
- elif "map_Pr" in prefix:
- roughness_path = os.path.join(os.path.dirname(path), split_line[1])
-
- # still not found albedo_path, or the path doesn't exist
- if albedo_path is None or not os.path.exists(albedo_path):
- # init an empty texture
- print(f"[load_obj] init empty albedo!")
- # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
- albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color
- else:
- albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
- albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
- albedo = albedo.astype(np.float32) / 255
- print(f"[load_obj] load texture: {albedo.shape}")
-
- mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
-
- # try to load metallic and roughness
- if metallic_path is not None and roughness_path is not None:
- print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
- metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
- metallic = metallic.astype(np.float32) / 255
- roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
- roughness = roughness.astype(np.float32) / 255
- metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
-
- mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
-
- return mesh
-
- @classmethod
- def load_trimesh(cls, path, device=None):
- """load a mesh using ``trimesh.load()``.
-
- Can load various formats like ``glb`` and serves as a fallback.
-
- Note:
- We will try to merge all meshes if the glb contains more than one,
- but **this may cause the texture to be lost**, since we only support one texture image!
-
- Args:
- path (str): path to the mesh file.
- device (torch.device, optional): torch device. Defaults to None.
-
- Returns:
- Mesh: the loaded Mesh object.
- """
- mesh = cls()
-
- # device
- if device is None:
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- mesh.device = device
-
- # use trimesh to load ply/glb
- _data = trimesh.load(path)
- if isinstance(_data, trimesh.Scene):
- if len(_data.geometry) == 1:
- _mesh = list(_data.geometry.values())[0]
- else:
- print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
- _concat = []
- # loop the scene graph and apply transform to each mesh
- scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}}
- for k, v in scene_graph.items():
- name = v['geometry']
- if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
- transform = v['transform']
- _concat.append(_data.geometry[name].apply_transform(transform))
- _mesh = trimesh.util.concatenate(_concat)
- else:
- _mesh = _data
-
- if _mesh.visual.kind == 'vertex':
- vertex_colors = _mesh.visual.vertex_colors
- vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
- mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
- print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
- elif _mesh.visual.kind == 'texture':
- _material = _mesh.visual.material
- if isinstance(_material, trimesh.visual.material.PBRMaterial):
- texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
- # load metallicRoughness if present
- if _material.metallicRoughnessTexture is not None:
- metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
- mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
- elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
- texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
- else:
- raise NotImplementedError(f"material type {type(_material)} not supported!")
- mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
- print(f"[load_trimesh] load texture: {texture.shape}")
- else:
- texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
- mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
- print(f"[load_trimesh] failed to load texture.")
-
- vertices = _mesh.vertices
-
- try:
- texcoords = _mesh.visual.uv
- texcoords[:, 1] = 1 - texcoords[:, 1]
- except Exception as e:
- texcoords = None
-
- try:
- normals = _mesh.vertex_normals
- except Exception as e:
- normals = None
-
- # trimesh only supports per-vertex uv...
- faces = tfaces = nfaces = _mesh.faces
-
- mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
- mesh.vt = (
- torch.tensor(texcoords, dtype=torch.float32, device=device)
- if texcoords is not None
- else None
- )
- mesh.vn = (
- torch.tensor(normals, dtype=torch.float32, device=device)
- if normals is not None
- else None
- )
-
- mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
- mesh.ft = (
- torch.tensor(tfaces, dtype=torch.int32, device=device)
- if texcoords is not None
- else None
- )
- mesh.fn = (
- torch.tensor(nfaces, dtype=torch.int32, device=device)
- if normals is not None
- else None
- )
-
- return mesh
-
- # sample surface (using trimesh)
- def sample_surface(self, count: int):
- """sample points on the surface of the mesh.
-
- Args:
- count (int): number of points to sample.
-
- Returns:
- torch.Tensor: the sampled points, float [count, 3].
- """
- _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
- points, face_idx = trimesh.sample.sample_surface(_mesh, count)
- points = torch.from_numpy(points).float().to(self.device)
- return points
-
- # aabb
- def aabb(self):
- """get the axis-aligned bounding box of the mesh.
-
- Returns:
- Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
- """
- return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
-
- # unit size
- @torch.no_grad()
- def auto_size(self, bound=0.9):
- """auto resize the mesh.
-
- Args:
- bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
- """
- vmin, vmax = self.aabb()
- self.ori_center = (vmax + vmin) / 2
- self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
- self.v = (self.v - self.ori_center) * self.ori_scale
-
- def auto_normal(self):
- """auto calculate the vertex normals.
- """
- i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
- v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
-
- face_normals = torch.cross(v1 - v0, v2 - v0)
-
- # Splat face normals to vertices
- vn = torch.zeros_like(self.v)
- vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
- vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
- vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
-
- # Normalize, replace zero (degenerated) normals with some default value
- vn = torch.where(
- dot(vn, vn) > 1e-20,
- vn,
- torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
- )
- vn = safe_normalize(vn)
-
- self.vn = vn
- self.fn = self.f
-
- def auto_uv(self, cache_path=None, vmap=True):
- """auto calculate the uv coordinates.
-
- Args:
- cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
- vmap (bool, optional): remap vertices based on uv coordinates, so each v corresponds to a unique vt (necessary for formats like gltf).
- Usually this will duplicate the vertices on the edge of the uv atlas. Defaults to True.
- """
- # try to load cache
- if cache_path is not None:
- cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
- if cache_path is not None and os.path.exists(cache_path):
- data = np.load(cache_path)
- vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
- else:
- import xatlas
-
- v_np = self.v.detach().cpu().numpy()
- f_np = self.f.detach().int().cpu().numpy()
- atlas = xatlas.Atlas()
- atlas.add_mesh(v_np, f_np)
- chart_options = xatlas.ChartOptions()
- # chart_options.max_iterations = 4
- atlas.generate(chart_options=chart_options)
- vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
-
- # save to cache
- if cache_path is not None:
- np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
-
- vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
- ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
- self.vt = vt
- self.ft = ft
-
- if vmap:
- vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
- self.align_v_to_vt(vmapping)
-
- def align_v_to_vt(self, vmapping=None):
- """ remap v/f and vn/fn to vt/ft.
-
- Args:
- vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
- """
- if vmapping is None:
- ft = self.ft.view(-1).long()
- f = self.f.view(-1).long()
- vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
- vmapping[ft] = f # scatter, randomly choose one if index is not unique
-
- self.v = self.v[vmapping]
- self.f = self.ft
-
- if self.vn is not None:
- self.vn = self.vn[vmapping]
- self.fn = self.ft
-
- def to(self, device):
- """move all tensor attributes to device.
-
- Args:
- device (torch.device): target device.
-
- Returns:
- Mesh: self.
- """
- self.device = device
- for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
- tensor = getattr(self, name)
- if tensor is not None:
- setattr(self, name, tensor.to(device))
- return self
-
- def write(self, path):
- """write the mesh to a path.
-
- Args:
- path (str): path to write, supports ply, obj and glb.
- """
- if path.endswith(".ply"):
- self.write_ply(path)
- elif path.endswith(".obj"):
- self.write_obj(path)
- elif path.endswith(".glb") or path.endswith(".gltf"):
- self.write_glb(path)
- else:
- raise NotImplementedError(f"format {path} not supported!")
-
- def write_ply(self, path):
- """write the mesh in ply format. Only for geometry!
-
- Args:
- path (str): path to write.
- """
-
- if self.albedo is not None:
- print(f'[WARN] ply format does not support exporting texture, will ignore!')
-
- v_np = self.v.detach().cpu().numpy()
- f_np = self.f.detach().cpu().numpy()
-
- _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
- _mesh.export(path)
-
-
- def write_glb(self, path):
- """write the mesh in glb/gltf format.
- This will create a scene with a single mesh.
-
- Args:
- path (str): path to write.
- """
-
- # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
- if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
- self.align_v_to_vt()
-
- import pygltflib
-
- f_np = self.f.detach().cpu().numpy().astype(np.uint32)
- f_np_blob = f_np.flatten().tobytes()
-
- v_np = self.v.detach().cpu().numpy().astype(np.float32)
- v_np_blob = v_np.tobytes()
-
- blob = f_np_blob + v_np_blob
- byteOffset = len(blob)
-
- # base mesh
- gltf = pygltflib.GLTF2(
- scene=0,
- scenes=[pygltflib.Scene(nodes=[0])],
- nodes=[pygltflib.Node(mesh=0)],
- meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
- # indices to accessors (0 is triangles)
- attributes=pygltflib.Attributes(
- POSITION=1,
- ),
- indices=0,
- )])],
- buffers=[
- pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
- ],
- # buffer view (based on dtype)
- bufferViews=[
- # triangles; as flatten (element) array
- pygltflib.BufferView(
- buffer=0,
- byteLength=len(f_np_blob),
- target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
- ),
- # positions; as vec3 array
- pygltflib.BufferView(
- buffer=0,
- byteOffset=len(f_np_blob),
- byteLength=len(v_np_blob),
- byteStride=12, # vec3
- target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
- ),
- ],
- accessors=[
- # 0 = triangles
- pygltflib.Accessor(
- bufferView=0,
- componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
- count=f_np.size,
- type=pygltflib.SCALAR,
- max=[int(f_np.max())],
- min=[int(f_np.min())],
- ),
- # 1 = positions
- pygltflib.Accessor(
- bufferView=1,
- componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
- count=len(v_np),
- type=pygltflib.VEC3,
- max=v_np.max(axis=0).tolist(),
- min=v_np.min(axis=0).tolist(),
- ),
- ],
- )
-
- # append texture info
- if self.vt is not None:
-
- vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
- vt_np_blob = vt_np.tobytes()
-
- albedo = self.albedo.detach().cpu().numpy()
- albedo = (albedo * 255).astype(np.uint8)
- albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
- albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
-
- # update primitive
- gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
- gltf.meshes[0].primitives[0].material = 0
-
- # update materials
- gltf.materials.append(pygltflib.Material(
- pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
- baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
- metallicFactor=0.0,
- roughnessFactor=1.0,
- ),
- alphaMode=pygltflib.OPAQUE,
- alphaCutoff=None,
- doubleSided=True,
- ))
-
- gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
- gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
- gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
-
- # update buffers
- gltf.bufferViews.append(
- # index = 2, texcoords; as vec2 array
- pygltflib.BufferView(
- buffer=0,
- byteOffset=byteOffset,
- byteLength=len(vt_np_blob),
- byteStride=8, # vec2
- target=pygltflib.ARRAY_BUFFER,
- )
- )
-
- gltf.accessors.append(
- # 2 = texcoords
- pygltflib.Accessor(
- bufferView=2,
- componentType=pygltflib.FLOAT,
- count=len(vt_np),
- type=pygltflib.VEC2,
- max=vt_np.max(axis=0).tolist(),
- min=vt_np.min(axis=0).tolist(),
- )
- )
-
- blob += vt_np_blob
- byteOffset += len(vt_np_blob)
-
- gltf.bufferViews.append(
- # index = 3, albedo texture; as none target
- pygltflib.BufferView(
- buffer=0,
- byteOffset=byteOffset,
- byteLength=len(albedo_blob),
- )
- )
-
- blob += albedo_blob
- byteOffset += len(albedo_blob)
-
- gltf.buffers[0].byteLength = byteOffset
-
- # append metallic roughness
- if self.metallicRoughness is not None:
- metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
- metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
- metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
- metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
-
- # update texture definition
- gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
- gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
- gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
-
- gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
- gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
- gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
-
- # update buffers
- gltf.bufferViews.append(
- # index = 4, metallicRoughness texture; as none target
- pygltflib.BufferView(
- buffer=0,
- byteOffset=byteOffset,
- byteLength=len(metallicRoughness_blob),
- )
- )
-
- blob += metallicRoughness_blob
- byteOffset += len(metallicRoughness_blob)
-
- gltf.buffers[0].byteLength = byteOffset
-
-
- # set actual data
- gltf.set_binary_blob(blob)
-
- # glb = b"".join(gltf.save_to_bytes())
- gltf.save(path)
-
-
- def write_obj(self, path):
- """write the mesh in obj format. Will also write the texture and mtl files.
-
- Args:
- path (str): path to write.
- """
-
- mtl_path = path.replace(".obj", ".mtl")
- albedo_path = path.replace(".obj", "_albedo.png")
- metallic_path = path.replace(".obj", "_metallic.png")
- roughness_path = path.replace(".obj", "_roughness.png")
-
- v_np = self.v.detach().cpu().numpy()
- vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
- vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
- f_np = self.f.detach().cpu().numpy()
- ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
- fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
-
- with open(path, "w") as fp:
- fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
-
- for v in v_np:
- fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
-
- if vt_np is not None:
- for v in vt_np:
- fp.write(f"vt {v[0]} {1 - v[1]} \n")
-
- if vn_np is not None:
- for v in vn_np:
- fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
-
- fp.write(f"usemtl defaultMat \n")
- for i in range(len(f_np)):
- fp.write(
- f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
- {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
- {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
- )
-
- with open(mtl_path, "w") as fp:
- fp.write(f"newmtl defaultMat \n")
- fp.write(f"Ka 1 1 1 \n")
- fp.write(f"Kd 1 1 1 \n")
- fp.write(f"Ks 0 0 0 \n")
- fp.write(f"Tr 1 \n")
- fp.write(f"illum 1 \n")
- fp.write(f"Ns 0 \n")
- if self.albedo is not None:
- fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
- if self.metallicRoughness is not None:
- # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
- fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
- fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
-
- if self.albedo is not None:
- albedo = self.albedo.detach().cpu().numpy()
- albedo = (albedo * 255).astype(np.uint8)
- cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
-
- if self.metallicRoughness is not None:
- metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
- metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
- cv2.imwrite(metallic_path, metallicRoughness[..., 2])
- cv2.imwrite(roughness_path, metallicRoughness[..., 1])
-
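For reference, a minimal usage sketch of the Mesh utility removed above, assuming the deleted module is importable as mesh, that xatlas is installed for the UV step, and that the file paths are placeholders:

import torch
from mesh import Mesh  # hypothetical import path for the deleted module

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load a textured mesh via trimesh, recenter/rescale it into [-0.9, 0.9]^3,
# recompute vertex normals, and unwrap UVs with xatlas if none were loaded
m = Mesh.load_trimesh("assets/input.glb", device=device)  # placeholder path
m.auto_size(bound=0.9)
m.auto_normal()
if m.vt is None:
    m.auto_uv(cache_path="assets/input")  # caches UVs to assets/input_uv.npz

# export; .obj also writes the .mtl and texture images, .glb embeds them
m.write("assets/output.glb")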
diff --git a/apps/third_party/CRM/model/.DS_Store b/apps/third_party/CRM/model/.DS_Store
deleted file mode 100644
index e2fd35f3e9054910a42dde149c88e430130d66d7..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/.DS_Store and /dev/null differ
diff --git a/apps/third_party/CRM/model/__init__.py b/apps/third_party/CRM/model/__init__.py
deleted file mode 100644
index b339e3ea9dac5482a6daed63b069e4d1eda000a8..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/model/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from model.crm.model import CRM
\ No newline at end of file
diff --git a/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index de5f1cf642bb8b776bf8da5e5d582f6001996da8..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/model/archs/__init__.py b/apps/third_party/CRM/model/archs/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 1b1f2b99930a199372b91624098bf60f9baa8469..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc
deleted file mode 100644
index da7fbd72afdb4465902b7f80d4dda57c8f3b9ee5..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc
deleted file mode 100644
index 932b63f9d151ae809cb092048dd560ac436567f7..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/model/archs/decoders/__init__.py b/apps/third_party/CRM/model/archs/decoders/__init__.py
deleted file mode 100644
index d3f5a12faa99758192ecc4ed3fc22c9249232e86..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/model/archs/decoders/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index da3e1e7536c21056bb8268bd03799770646488ae..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc b/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc
deleted file mode 100644
index 2fab55d21095a1a7e4841140ca2efb1c71f9730c..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py b/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py
deleted file mode 100644
index 5e5ddd78215b9f48b281757a91b8de6f73e4742a..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class TetTexNet(nn.Module):
- def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
- super().__init__()
- # self.c_dim = c_dim
- self.plane_reso = plane_reso
- self.padding = padding
- self.fea_concat = fea_concat
-
- def forward(self, rolled_out_feature, query):
- # rolled_out_feature: rolled-out triplane feature
- # query: queried xyz coordinates (should be scaled consistently with the point cloud)
-
- plane_reso = self.plane_reso
-
- triplane_feature = dict()
- triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
- triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
- triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]
-
- query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
- query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
- query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')
-
- if self.fea_concat:
- query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
- else:
- query_feature = query_feature_xy + query_feature_yz + query_feature_zx
-
- output = query_feature.permute(0, 2, 1)
-
- return output
-
- # uses values from plane_feature and pixel locations from vgrid to interpolate feature
- def sample_plane_feature(self, query, plane_feature, plane):
- # CYF note:
- # for pretraining, queries are uniformly sampled positions within [-scale, scale]
- # for training, queries are essentially tetrahedral grid vertices, which are
- # also within [-scale, scale] in the current version!
- # xy range [-scale, scale]
- if plane == 'xy':
- xy = query[:, :, [0, 1]]
- elif plane == 'yz':
- xy = query[:, :, [1, 2]]
- elif plane == 'zx':
- xy = query[:, :, [2, 0]]
- else:
- raise ValueError("Error! Invalid plane type!")
-
- xy = xy[:, :, None].float()
- # it does not seem necessary to rescale the grid, because according to
- # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
- # sampling locations are normalized by plane_feature's spatial dimensions,
- # which are within [-scale, scale] as specified by the encoder's call to coordinate2index()
- vgrid = 1.0 * xy
- sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
-
- return sampled_feat
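The TetTexNet removed above looks up a feature for each query point by projecting it onto the three planes of a rolled-out triplane and sampling each plane with F.grid_sample. A self-contained sketch of that lookup, with shapes chosen purely for illustration rather than taken from the CRM configs:

import torch
import torch.nn.functional as F

B, C, reso, N = 2, 32, 64, 1024
rolled_out = torch.randn(B, C, reso, reso * 3)   # xy / yz / zx planes concatenated along width
query = torch.rand(B, N, 3) * 2 - 1              # xyz coordinates in [-1, 1]

planes = {
    "xy": rolled_out[..., 0:reso],
    "yz": rolled_out[..., reso:2 * reso],
    "zx": rolled_out[..., 2 * reso:],
}
axes = {"xy": [0, 1], "yz": [1, 2], "zx": [2, 0]}

feats = []
for name, plane in planes.items():
    vgrid = query[:, :, None, :][..., axes[name]]      # [B, N, 1, 2] sampling locations
    sampled = F.grid_sample(plane, vgrid, mode="bilinear",
                            padding_mode="border", align_corners=True)
    feats.append(sampled.squeeze(-1))                  # [B, C, N]

# fea_concat=True behaviour: concatenate the three plane features per point
query_feature = torch.cat(feats, dim=1).permute(0, 2, 1)   # [B, N, 3 * C]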
diff --git a/apps/third_party/CRM/model/archs/mlp_head.py b/apps/third_party/CRM/model/archs/mlp_head.py
deleted file mode 100644
index 33d7dcdbf58374dd036d9f3f5f0bfd3f248e845b..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/model/archs/mlp_head.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class SdfMlp(nn.Module):
- def __init__(self, input_dim, hidden_dim=512, bias=True):
- super().__init__()
- self.input_dim = input_dim
- self.hidden_dim = hidden_dim
-
- self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
- self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
- self.fc3 = nn.Linear(hidden_dim, 4, bias=bias)
-
-
- def forward(self, input):
- x = F.relu(self.fc1(input))
- x = F.relu(self.fc2(x))
- out = self.fc3(x)
- return out
-
-
-class RgbMlp(nn.Module):
- def __init__(self, input_dim, hidden_dim=512, bias=True):
- super().__init__()
- self.input_dim = input_dim
- self.hidden_dim = hidden_dim
-
- self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
- self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
- self.fc3 = nn.Linear(hidden_dim, 3, bias=bias)
-
- def forward(self, input):
- x = F.relu(self.fc1(input))
- x = F.relu(self.fc2(x))
- out = self.fc3(x)
-
- return out
-
-
\ No newline at end of file
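A minimal sketch of how the two heads removed above are typically driven by per-vertex triplane features. The feature width is illustrative (three 32-channel planes concatenated, as with fea_concat=True), and the import mirrors how the deleted model.py pulls these classes in:

import torch
from model.archs.mlp_head import SdfMlp, RgbMlp

feat_dim = 3 * 32                                # concatenated triplane features
sdf_head = SdfMlp(feat_dim, hidden_dim=512)
rgb_head = RgbMlp(feat_dim, hidden_dim=512)

dec_verts = torch.randn(1, 4096, feat_dim)       # features queried at grid vertices
out = sdf_head(dec_verts)                        # [1, 4096, 4]
sdf, deformation = out[..., 0], out[..., 1:]     # scalar field plus per-vertex offset
rgb = rgb_head(dec_verts)                        # [1, 4096, 3], expected in [-1, 1]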
diff --git a/apps/third_party/CRM/model/archs/unet.py b/apps/third_party/CRM/model/archs/unet.py
deleted file mode 100644
index e427c18b1bed00089e87fe25b0e810042538d6b3..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/model/archs/unet.py
+++ /dev/null
@@ -1,53 +0,0 @@
-'''
-Codes are from:
-https://github.com/jaxony/unet-pytorch/blob/master/model.py
-'''
-
-import torch
-import torch.nn as nn
-from diffusers import UNet2DModel
-import einops
-class UNetPP(nn.Module):
- '''
- Wrapper for UNet in diffusers
- '''
- def __init__(self, in_channels):
- super(UNetPP, self).__init__()
- self.in_channels = in_channels
- self.unet = UNet2DModel(
- sample_size=[256, 256*3],
- in_channels=in_channels,
- out_channels=32,
- layers_per_block=2,
- block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
- down_block_types=(
- "DownBlock2D",
- "DownBlock2D",
- "DownBlock2D",
- "AttnDownBlock2D",
- "AttnDownBlock2D",
- "AttnDownBlock2D",
- "DownBlock2D",
- ),
- up_block_types=(
- "UpBlock2D",
- "AttnUpBlock2D",
- "AttnUpBlock2D",
- "AttnUpBlock2D",
- "UpBlock2D",
- "UpBlock2D",
- "UpBlock2D",
- ),
- )
-
- self.unet.enable_xformers_memory_efficient_attention()
- if in_channels > 12:
- self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3]))
-
- def forward(self, x, t=256):
- if x.shape[1] < self.in_channels:
- learned_plane = einops.repeat(self.learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device)
- x = torch.cat([x, learned_plane], dim=1)
- return self.unet(x, t).sample
- return self.unet(x, t).sample
-
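UNetPP above runs a 2D diffusers UNet over a triplane that has been rolled out into a single 256 x 768 feature image. A small, self-contained sketch of that roll-out and roll-up with einops (channel count and batch size are illustrative and independent of the UNet weights):

import torch
import einops

B, C, reso = 1, 12, 256
triplane = torch.randn(B, 3, C, reso, reso)              # three axis-aligned planes

# roll out: concatenate the planes along the width -> [B, C, 256, 768]
rolled = einops.rearrange(triplane, "b p c h w -> b c h (p w)")

# roll back up after the UNet pass
restored = einops.rearrange(rolled, "b c h (p w) -> b p c h w", p=3)
assert torch.equal(triplane, restored)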
diff --git a/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc b/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc
deleted file mode 100644
index 93b38087a8e1a71bab18af950898769b765939ab..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/model/crm/model.py b/apps/third_party/CRM/model/crm/model.py
deleted file mode 100644
index 9eb164266a0c23295ab3451d99d720ec12afa468..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/model/crm/model.py
+++ /dev/null
@@ -1,213 +0,0 @@
-import torch.nn as nn
-import torch
-import torch.nn.functional as F
-
-import numpy as np
-
-
-from pathlib import Path
-import cv2
-import trimesh
-import nvdiffrast.torch as dr
-
-from model.archs.decoders.shape_texture_net import TetTexNet
-from model.archs.unet import UNetPP
-from util.renderer import Renderer
-from model.archs.mlp_head import SdfMlp, RgbMlp
-import xatlas
-
-
-class Dummy:
- pass
-
-class CRM(nn.Module):
- def __init__(self, specs):
- super(CRM, self).__init__()
-
- self.specs = specs
- # configs
- input_specs = specs["Input"]
- self.input = Dummy()
- self.input.scale = input_specs['scale']
- self.input.resolution = input_specs['resolution']
- self.tet_grid_size = input_specs['tet_grid_size']
- self.camera_angle_num = input_specs['camera_angle_num']
-
- self.arch = Dummy()
- self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"]
- self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"]
-
- self.dec = Dummy()
- self.dec.c_dim = specs["DecoderSpecs"]["c_dim"]
- self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"]
-
- self.geo_type = specs["Train"].get("geo_type", "flex") # "dmtet" or "flex"
-
- self.unet2 = UNetPP(in_channels=self.dec.c_dim)
-
- mlp_chnl_s = 3 if self.arch.fea_concat else 1 # 3 for queried triplane feature concatenation
- self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat)
-
- if self.geo_type == "flex":
- self.weightMlp = nn.Sequential(
- nn.Linear(mlp_chnl_s * 32 * 8, 512),
- nn.SiLU(),
- nn.Linear(512, 21))
-
- self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
- self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
- self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num,
- scale=self.input.scale, geo_type = self.geo_type)
-
-
- self.spob = True if specs['Pretrain']['mode'] is None else False # whether to add sphere
- self.radius = specs['Pretrain']['radius'] # used when spob
-
- self.denoising = True
- from diffusers import DDIMScheduler
- self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
-
- def decode(self, data, triplane_feature2):
- if self.geo_type == "flex":
- tet_verts = self.renderer.flexicubes.verts.unsqueeze(0)
- tet_indices = self.renderer.flexicubes.indices
-
- dec_verts = self.decoder(triplane_feature2, tet_verts)
- out = self.sdfMlp(dec_verts)
-
- weight = None
- if self.geo_type == "flex":
- grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1),dim=1)
- grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1])
- weight = self.weightMlp(grid_feat)
- weight = weight * 0.1
-
- pred_sdf, deformation = out[..., 0], out[..., 1:]
- if self.spob:
- pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1))
-
- _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight= weight)
- return verts[0].unsqueeze(0), faces[0].int()
-
- def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2 = None):
- verts = data['verts']
- faces = data['faces']
-
- dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0))
- colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy()
- # Predicted color values are expected to lie in [-1, 1]
- colors = (colors * 0.5 + 0.5).clip(0, 1)
-
- verts = verts.squeeze().cpu().numpy()
- faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy()
-
- # export the final mesh
- with torch.no_grad():
- mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault...
- mesh.export(out_dir / f'{ind}.obj')
-
- def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):
-
- mesh_v = data['verts'].squeeze().cpu().numpy()
- mesh_pos_idx = data['faces'].squeeze().cpu().numpy()
-
- def interpolate(attr, rast, attr_idx, rast_db=None):
- return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db,
- diff_attrs=None if rast_db is None else 'all')
-
- vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx)
-
- mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device)
- mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device)
-
- # Convert to tensors
- indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
-
- uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
- mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
- # mesh_v_tex: texture-space uv coordinates
- uv_clip = uvs[None, ...] * 2.0 - 1.0
-
- # pad to four component coordinate
- uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
-
- # rasterize
- rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res)
-
- # Interpolate world space position
- gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
- mask = rast[..., 3:4] > 0
-
- # return uvs, mesh_tex_idx, gb_pos, mask
- gb_pos_unsqz = gb_pos.view(-1, 3)
- mask_unsqz = mask.view(-1)
- tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1
-
- gb_mask_pos = gb_pos_unsqz[mask_unsqz]
-
- gb_mask_pos = gb_mask_pos[None, ]
-
- with torch.no_grad():
-
- dec_verts = self.decoder(tri_fea_2, gb_mask_pos)
- colors = self.rgbMlp(dec_verts).squeeze()
-
- # Predicted color values are expected to lie in [-1, 1]
- lo, hi = (-1, 1)
- colors = (colors - lo) * (255 / (hi - lo))
- colors = colors.clip(0, 255)
-
- tex_unsqz[mask_unsqz] = colors
-
- tex = tex_unsqz.view(res + (3,))
-
- verts = mesh_v.squeeze().cpu().numpy()
- faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy()
- # faces = mesh_pos_idx
- # faces = faces.detach().cpu().numpy()
- # faces = faces[..., [2, 1, 0]]
- indices = indices[..., [2, 1, 0]]
-
- # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs)
- matname = f'{out_dir}.mtl'
- # matname = f'{out_dir}/{ind}.mtl'
- fid = open(matname, 'w')
- fid.write('newmtl material_0\n')
- fid.write('Kd 1 1 1\n')
- fid.write('Ka 1 1 1\n')
- # fid.write('Ks 0 0 0\n')
- fid.write('Ks 0.4 0.4 0.4\n')
- fid.write('Ns 10\n')
- fid.write('illum 2\n')
- fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n')
- fid.close()
-
- fid = open(f'{out_dir}.obj', 'w')
- # fid = open(f'{out_dir}/{ind}.obj', 'w')
- fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1])
-
- for pidx, p in enumerate(verts):
- pp = p
- fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1]))
-
- for pidx, p in enumerate(uvs):
- pp = p
- fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))
-
- fid.write('usemtl material_0\n')
- for i, f in enumerate(faces):
- f1 = f + 1
- f2 = indices[i] + 1
- fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
- fid.close()
-
- img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32)
- mask = np.sum(img.astype(float), axis=-1, keepdims=True)
- mask = (mask <= 3.0).astype(float)
- kernel = np.ones((3, 3), 'uint8')
- dilate_img = cv2.dilate(img, kernel, iterations=1)
- img = img * (1 - mask) + dilate_img * mask
- img = img.clip(0, 255).astype(np.uint8)
-
- cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]])
- # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
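The tail of export_mesh_wt_uv above fills texels that were never rasterized by dilating neighbouring colours into them, which hides seams at UV chart borders. A stand-alone sketch of that hole-filling trick on synthetic data (not the CRM texture):

import numpy as np
import cv2

res = 64
img = np.ones((res, res, 3), dtype=np.float32)   # unwritten texels keep their init value of 1
img[16:48, 16:48] = 200.0                        # pretend only this patch was rasterized

# texels whose channel sum is <= 3 were never written and count as holes
hole_mask = (np.sum(img, axis=-1, keepdims=True) <= 3.0).astype(np.float32)

kernel = np.ones((3, 3), np.uint8)
dilated = cv2.dilate(img, kernel, iterations=1)
filled = img * (1 - hole_mask) + dilated * hole_mask
filled = filled.clip(0, 255).astype(np.uint8)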
diff --git a/apps/third_party/CRM/pipelines.py b/apps/third_party/CRM/pipelines.py
deleted file mode 100644
index 0ef19dc84dd7789197d02a29239d99b0a82558b1..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/pipelines.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import torch
-import os
-import sys
-proj_dir = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(proj_dir)
-from .libs.base_utils import do_resize_content
-from .imagedream.ldm.util import (
- instantiate_from_config,
- get_obj_from_str,
-)
-from omegaconf import OmegaConf
-from PIL import Image
-import PIL
-import rembg
-class TwoStagePipeline(object):
- def __init__(
- self,
- stage1_model_config,
- stage1_sampler_config,
- device="cuda",
- dtype=torch.float16,
- resize_rate=1,
- ) -> None:
- """
- Only for the two-stage generation process.
- - the first stage is conditioned on a single pixel image and generates multi-view pixel images, based on the v2pp config
- - the second stage is conditioned on the multi-view pixel images generated by the first stage and generates the final image, based on the stage2-test config
- """
- self.resize_rate = resize_rate
-
- self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
- self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
- self.stage1_model = self.stage1_model.to(device).to(dtype)
-
- self.stage1_model.device = device
- self.device = device
- self.dtype = dtype
- self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
- self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
- )
-
- def stage1_sample(
- self,
- pixel_img,
- prompt="3D assets",
- neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
- step=50,
- scale=5,
- ddim_eta=0.0,
- ):
- if type(pixel_img) == str:
- pixel_img = Image.open(pixel_img)
-
- if isinstance(pixel_img, Image.Image):
- if pixel_img.mode == "RGBA":
- background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
- pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
- else:
- pixel_img = pixel_img.convert("RGB")
- else:
- raise TypeError(f"unsupported input type: {type(pixel_img)}")
- uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
- stage1_images = self.stage1_sampler.i2i(
- self.stage1_sampler.model,
- self.stage1_sampler.size,
- prompt,
- uc=uc,
- sampler=self.stage1_sampler.sampler,
- ip=pixel_img,
- step=step,
- scale=scale,
- batch_size=self.stage1_sampler.batch_size,
- ddim_eta=ddim_eta,
- dtype=self.stage1_sampler.dtype,
- device=self.stage1_sampler.device,
- camera=self.stage1_sampler.camera,
- num_frames=self.stage1_sampler.num_frames,
- pixel_control=(self.stage1_sampler.mode == "pixel"),
- transform=self.stage1_sampler.image_transform,
- offset_noise=self.stage1_sampler.offset_noise,
- )
-
- stage1_images = [Image.fromarray(img) for img in stage1_images]
- stage1_images.pop(self.stage1_sampler.ref_position)
- return stage1_images
-
- def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
- if type(pixel_img) == str:
- pixel_img = Image.open(pixel_img)
-
- if isinstance(pixel_img, Image.Image):
- if pixel_img.mode == "RGBA":
- background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
- pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
- else:
- pixel_img = pixel_img.convert("RGB")
- else:
- raise TypeError(f"unsupported input type: {type(pixel_img)}")
- stage2_images = self.stage2_sampler.i2iStage2(
- self.stage2_sampler.model,
- self.stage2_sampler.size,
- "3D assets",
- self.stage2_sampler.uc,
- self.stage2_sampler.sampler,
- pixel_images=stage1_images,
- ip=pixel_img,
- step=step,
- scale=scale,
- batch_size=self.stage2_sampler.batch_size,
- ddim_eta=0.0,
- dtype=self.stage2_sampler.dtype,
- device=self.stage2_sampler.device,
- camera=self.stage2_sampler.camera,
- num_frames=self.stage2_sampler.num_frames,
- pixel_control=(self.stage2_sampler.mode == "pixel"),
- transform=self.stage2_sampler.image_transform,
- offset_noise=self.stage2_sampler.offset_noise,
- )
- stage2_images = [Image.fromarray(img) for img in stage2_images]
- return stage2_images
-
- def set_seed(self, seed):
- self.stage1_sampler.seed = seed
- # self.stage2_sampler.seed = seed
-
- def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
- pixel_img = do_resize_content(pixel_img, self.resize_rate)
- stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
- # stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
-
- return {
- "ref_img": pixel_img,
- "stage1_images": stage1_images,
- # "stage2_images": stage2_images,
- }
-
-rembg_session = rembg.new_session()
-
-def expand_to_square(image, bg_color=(0, 0, 0, 0)):
- # expand image to 1:1
- width, height = image.size
- if width == height:
- return image
- new_size = (max(width, height), max(width, height))
- new_image = Image.new("RGBA", new_size, bg_color)
- paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
- new_image.paste(image, paste_position)
- return new_image
-
-def remove_background(
- image: PIL.Image.Image,
- rembg_session = None,
- force: bool = False,
- **rembg_kwargs,
-) -> PIL.Image.Image:
- do_remove = True
- if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
- # explain why we currently do not remove the background
- print("alpha channel not empty, skip background removal and use the alpha channel as mask")
- background = Image.new("RGBA", image.size, (0, 0, 0, 0))
- image = Image.alpha_composite(background, image)
- do_remove = False
- do_remove = do_remove or force
- if do_remove:
- image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
- return image
-
-def do_resize_content(original_image: Image, scale_rate):
- # resize image content while retaining the original image size
- if scale_rate != 1:
- # Calculate the new size after rescaling
- new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
- # Resize the image while maintaining the aspect ratio
- resized_image = original_image.resize(new_size)
- # Create a new image with the original size and black background
- padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
- paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
- padded_image.paste(resized_image, paste_position)
- return padded_image
- else:
- return original_image
-
-def add_background(image, bg_color=(255, 255, 255)):
- # given an RGBA image, alpha channel is used as mask to add background color
- background = Image.new("RGBA", image.size, bg_color)
- return Image.alpha_composite(background, image)
-
-
-def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
- """
- input image is a PIL image in RGBA; returns an RGB image
- """
- print(background_choice)
- if background_choice == "Alpha as mask":
- background = Image.new("RGBA", image.size, (0, 0, 0, 0))
- image = Image.alpha_composite(background, image)
- else:
- image = remove_background(image, rembg_session, force=True)
- image = do_resize_content(image, foreground_ratio)
- image = expand_to_square(image)
- image = add_background(image, backgroud_color)
- return image.convert("RGB")
-
-
-
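A minimal driver for the preprocessing helpers above; the input path is a placeholder, and the "Alpha as mask" branch is chosen here because it only touches PIL (no background-removal model is invoked):

from PIL import Image

img = Image.open("examples/kunkun.webp").convert("RGBA")   # placeholder input
img = preprocess_image(img, "Alpha as mask",
                       foreground_ratio=1.0, backgroud_color=(127, 127, 127))
img.save("preprocessed_image.png")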
diff --git a/apps/third_party/CRM/requirements.txt b/apps/third_party/CRM/requirements.txt
deleted file mode 100644
index 8501f40d7ec31f64b3b6b77549fb2b7623f2d382..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/requirements.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-gradio
-huggingface-hub
-diffusers==0.24.0
-einops==0.7.0
-Pillow==10.1.0
-transformers==4.27.1
-open-clip-torch==2.7.0
-opencv-contrib-python-headless==4.9.0.80
-opencv-python-headless==4.9.0.80
-omegaconf
-rembg
-pygltflib
-kiui
-trimesh
-xatlas
-pymeshlab
diff --git a/apps/third_party/CRM/run.py b/apps/third_party/CRM/run.py
deleted file mode 100644
index 8e14be6e7c7cb1d314d1e82a23a6d250e79ce3b7..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/run.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import torch
-from libs.base_utils import do_resize_content
-from imagedream.ldm.util import (
- instantiate_from_config,
- get_obj_from_str,
-)
-from omegaconf import OmegaConf
-from PIL import Image
-import numpy as np
-from inference import generate3d
-from huggingface_hub import hf_hub_download
-import json
-import argparse
-import shutil
-from model import CRM
-import PIL
-import rembg
-import os
-from pipelines import TwoStagePipeline
-
-rembg_session = rembg.new_session()
-
-def expand_to_square(image, bg_color=(0, 0, 0, 0)):
- # expand image to 1:1
- width, height = image.size
- if width == height:
- return image
- new_size = (max(width, height), max(width, height))
- new_image = Image.new("RGBA", new_size, bg_color)
- paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
- new_image.paste(image, paste_position)
- return new_image
-
-def remove_background(
- image: PIL.Image.Image,
- rembg_session = None,
- force: bool = False,
- **rembg_kwargs,
-) -> PIL.Image.Image:
- do_remove = True
- if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
- # explain why we currently do not remove the background
- print("alpha channel not empty, skip background removal and use the alpha channel as mask")
- background = Image.new("RGBA", image.size, (0, 0, 0, 0))
- image = Image.alpha_composite(background, image)
- do_remove = False
- do_remove = do_remove or force
- if do_remove:
- image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
- return image
-
-def do_resize_content(original_image: Image, scale_rate):
- # resize image content while retaining the original image size
- if scale_rate != 1:
- # Calculate the new size after rescaling
- new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
- # Resize the image while maintaining the aspect ratio
- resized_image = original_image.resize(new_size)
- # Create a new image with the original size and black background
- padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
- paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
- padded_image.paste(resized_image, paste_position)
- return padded_image
- else:
- return original_image
-
-def add_background(image, bg_color=(255, 255, 255)):
- # given an RGBA image, alpha channel is used as mask to add background color
- background = Image.new("RGBA", image.size, bg_color)
- return Image.alpha_composite(background, image)
-
-
-def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
- """
- input image is a PIL image in RGBA; returns an RGB image
- """
- print(background_choice)
- if background_choice == "Alpha as mask":
- background = Image.new("RGBA", image.size, (0, 0, 0, 0))
- image = Image.alpha_composite(background, image)
- else:
- image = remove_background(image, rembg_session, force=True)
- image = do_resize_content(image, foreground_ratio)
- image = expand_to_square(image)
- image = add_background(image, backgroud_color)
- return image.convert("RGB")
-
-if __name__ == "__main__":
-
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--inputdir",
- type=str,
- default="examples/kunkun.webp",
- help="dir for input image",
- )
- parser.add_argument(
- "--scale",
- type=float,
- default=5.0,
- )
- parser.add_argument(
- "--step",
- type=int,
- default=50,
- )
- parser.add_argument(
- "--bg_choice",
- type=str,
- default="Auto Remove background",
- help="[Auto Remove background] or [Alpha as mask]",
- )
- parser.add_argument(
- "--outdir",
- type=str,
- default="out/",
- )
- args = parser.parse_args()
-
-
- img = Image.open(args.inputdir)
- img = preprocess_image(img, args.bg_choice, 1.0, (127, 127, 127))
- os.makedirs(args.outdir, exist_ok=True)
- img.save(args.outdir+"preprocessed_image.png")
-
- crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
- specs = json.load(open("configs/specs_objaverse_total.json"))
- model = CRM(specs).to("cuda")
- model.load_state_dict(torch.load(crm_path, map_location = "cuda"), strict=False)
-
- stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
- stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
- stage2_sampler_config = stage2_config.sampler
- stage1_sampler_config = stage1_config.sampler
-
- stage1_model_config = stage1_config.models
- stage2_model_config = stage2_config.models
-
- xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
- pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
- stage1_model_config.resume = pixel_path
- stage2_model_config.resume = xyz_path
-
- pipeline = TwoStagePipeline(
- stage1_model_config,
- stage2_model_config,
- stage1_sampler_config,
- stage2_sampler_config,
- )
-
- rt_dict = pipeline(img, scale=args.scale, step=args.step)
- stage1_images = rt_dict["stage1_images"]
- stage2_images = rt_dict["stage2_images"]
- np_imgs = np.concatenate(stage1_images, 1)
- np_xyzs = np.concatenate(stage2_images, 1)
- Image.fromarray(np_imgs).save(args.outdir+"pixel_images.png")
- Image.fromarray(np_xyzs).save(args.outdir+"xyz_images.png")
-
- glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, "cuda")
- shutil.copy(obj_path, args.outdir+"output3d.zip")
\ No newline at end of file
diff --git a/apps/third_party/CRM/util/__init__.py b/apps/third_party/CRM/util/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/apps/third_party/CRM/util/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index c2271f9a2aee0aef5a1ec6b84fa433a3e04bfcd6..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/util/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/util/__pycache__/flexicubes.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/flexicubes.cpython-38.pyc
deleted file mode 100644
index 37685af5cd939c94104af01552419fc96e7e4688..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/util/__pycache__/flexicubes.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/util/__pycache__/flexicubes_geometry.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/flexicubes_geometry.cpython-38.pyc
deleted file mode 100644
index 7b063e820c975d5640b6f1c2684f633f461a9fac..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/util/__pycache__/flexicubes_geometry.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/util/__pycache__/renderer.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/renderer.cpython-38.pyc
deleted file mode 100644
index 5fc688d3238a84ca4ffd4c61656f7a63679d72e8..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/util/__pycache__/renderer.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/util/__pycache__/tables.cpython-38.pyc b/apps/third_party/CRM/util/__pycache__/tables.cpython-38.pyc
deleted file mode 100644
index e9acbd82a2a5d5a7ef03a34c9b89035b2bcd6b39..0000000000000000000000000000000000000000
Binary files a/apps/third_party/CRM/util/__pycache__/tables.cpython-38.pyc and /dev/null differ
diff --git a/apps/third_party/CRM/util/flexicubes.py b/apps/third_party/CRM/util/flexicubes.py
deleted file mode 100644
index 0e12d371362ea7f7f315bd33d866f4bb3510eadb..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/util/flexicubes.py
+++ /dev/null
@@ -1,579 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
-import torch
-from util.tables import *
-
-__all__ = [
- 'FlexiCubes'
-]
-
-
-class FlexiCubes:
- """
- This class implements the FlexiCubes method for extracting meshes from scalar fields.
- It maintains a series of lookup tables and indices to support the mesh extraction process.
- FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
- the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
- the surface representation through gradient-based optimization.
-
- During instantiation, the class loads DMC tables from a file and transforms them into
- PyTorch tensors on the specified device.
-
- Attributes:
- device (str): Specifies the computational device (default is "cuda").
- dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
- associated with each dual vertex in 256 Marching Cubes (MC) configurations.
- num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
- the 256 MC configurations.
- check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
- of the DMC configurations.
- tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
- quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
- along one diagonal.
- quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
- two triangles along the other diagonal.
- quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
- during training by connecting all edges to their midpoints.
- cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
- eight corners in 3D space, ordered starting from the origin (0,0,0),
- moving along the x-axis, then y-axis, and finally z-axis.
- Used as a blueprint for generating a voxel grid.
- cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
- to retrieve the case id.
- cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
- Used to retrieve edge vertices in DMC.
- edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
- their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
- first edge is oriented along the x-axis.
- dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
- across four adjacent cubes to the shared faces of these cubes. For instance,
- dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
- the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
- This tensor is only utilized during isosurface tetrahedralization.
- adj_pairs (torch.Tensor):
- A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
- qef_reg_scale (float):
- The scaling factor applied to the regularization loss to prevent issues with singularity
- when solving the QEF. This parameter is only used when a 'grad_func' is specified.
- weight_scale (float):
- The scale of weights in FlexiCubes. Should be between 0 and 1.
- """
-
- def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
-
- self.device = device
- self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
- self.num_vd_table = torch.tensor(num_vd_table,
- dtype=torch.long, device=device, requires_grad=False)
- self.check_table = torch.tensor(
- check_table,
- dtype=torch.long, device=device, requires_grad=False)
-
- self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
- self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
- self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
- self.quad_split_train = torch.tensor(
- [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
-
- self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
- 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
- self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
- self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
- 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
-
- self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
- dtype=torch.long, device=device)
- self.dir_faces_table = torch.tensor([
- [[5, 4], [3, 2], [4, 5], [2, 3]],
- [[5, 4], [1, 0], [4, 5], [0, 1]],
- [[3, 2], [1, 0], [2, 3], [0, 1]]
- ], dtype=torch.long, device=device)
- self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
- self.qef_reg_scale = qef_reg_scale
- self.weight_scale = weight_scale
-
- def construct_voxel_grid(self, res):
- """
- Generates a voxel grid based on the specified resolution.
-
- Args:
- res (int or list[int]): The resolution of the voxel grid. If an integer
- is provided, it is used for all three dimensions. If a list or tuple
- of 3 integers is provided, they define the resolution for the x,
- y, and z dimensions respectively.
-
- Returns:
- (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
- cube corners (index into vertices) of the constructed voxel grid.
- The vertices are centered at the origin, with the length of each
- dimension in the grid being one.
- """
- base_cube_f = torch.arange(8).to(self.device)
- if isinstance(res, int):
- res = (res, res, res)
- voxel_grid_template = torch.ones(res, device=self.device)
-
- res = torch.tensor([res], dtype=torch.float, device=self.device)
- coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
- verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
- cubes = (base_cube_f.unsqueeze(0) +
- torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
-
- verts_rounded = torch.round(verts * 10**5) / (10**5)
- verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
- cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
-
- return verts_unique - 0.5, cubes
-
- def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
- gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
- r"""
- Main function for mesh extraction from scalar field using FlexiCubes. This function converts
- discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
- to triangle or tetrahedral meshes using a differentiable operation as described in
- `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
- mesh quality and geometric fidelity by adjusting the surface representation based on gradient
- optimization. The output surface is differentiable with respect to the input vertex positions,
- scalar field values, and weight parameters.
-
- If you intend to extract a surface mesh from a fixed Signed Distance Field without the
- optimization of parameters, it is suggested to provide the "grad_func" which should
- return the surface gradient at any given 3D position. When grad_func is provided, the process
- to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
- described in the `Manifold Dual Contouring`_ paper, and employs a smart splitting strategy.
- Please note, this approach is non-differentiable.
-
- For more details and example usage in optimization, refer to the
- `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
-
- Args:
- x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
- s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
- denote that the corresponding vertex resides inside the isosurface. This affects
- the directions of the extracted triangle faces and volume to be tetrahedralized.
- cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
- res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
- is used for all three dimensions. If a list or tuple of 3 integers is provided, they
- specify the resolution for the x, y, and z dimensions respectively.
- beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges used to adjust
- dual vertex positioning. Defaults to a uniform value for all edges.
- alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners used to adjust
- dual vertex positioning. Defaults to a uniform value for all vertices.
- gamma_f (torch.Tensor, optional): Weight parameters controlling the splitting of
- quadrilaterals into triangles. Defaults to a uniform value for all cubes.
- training (bool, optional): If set to True, applies differentiable quad splitting for
- training. Defaults to False.
- output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
- outputs a triangular mesh. Defaults to False.
- grad_func (callable, optional): A function to compute the surface gradient at specified
- 3D positions (input: Nx3 positions). The function should return gradients as an Nx3
- tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
-
- Returns:
- (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
- - Vertices for the extracted triangular/tetrahedral mesh.
- - Faces for the extracted triangular/tetrahedral mesh.
- - Regularizer L_dev, computed per dual vertex.
-
- .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
- https://research.nvidia.com/labs/toronto-ai/flexicubes/
- .. _Manifold Dual Contouring:
- https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
- """
-
- surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
- if surf_cubes.sum() == 0:
- return (torch.zeros((0, 3), device=self.device),
- torch.zeros((0, 4), dtype=torch.long, device=self.device) if output_tetmesh
- else torch.zeros((0, 3), dtype=torch.long, device=self.device),
- torch.zeros((0,), device=self.device))
- beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
-
- case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
-
- surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
-
- vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
- x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
- vertices, faces, s_edges, edge_indices = self._triangulate(
- s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
- if not output_tetmesh:
- return vertices, faces, L_dev
- else:
- vertices, tets = self._tetrahedralize(
- x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
- surf_cubes, training)
- return vertices, tets, L_dev
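To make the call signature documented above concrete, here is a hedged usage sketch that extracts a mesh from an analytic sphere SDF with default (uniform) weights; the sphere and the resolution are illustrative assumptions, while the argument order mirrors get_mesh() in flexicubes_geometry.py below:

import torch
from util.flexicubes import FlexiCubes

device = 'cuda' if torch.cuda.is_available() else 'cpu'
res = 64
fc = FlexiCubes(device)
x_nx3, cube_fx8 = fc.construct_voxel_grid(res)

# Illustrative scalar field: signed distance to a sphere of radius 0.3 (negative = inside).
s_n = x_nx3.norm(dim=-1) - 0.3

vertices, faces, L_dev = fc(x_nx3, s_n, cube_fx8, res)                    # triangle mesh
tet_verts, tets, _ = fc(x_nx3, s_n, cube_fx8, res, output_tetmesh=True)   # tetrahedral mesh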
-
- def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
- """
- Regularizer L_dev as in Equation 8
- """
- dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
- mean_l2 = torch.zeros_like(vd[:, 0])
- mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
- mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
- return mad
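Spelled out, the deviation computed here (our reconstruction from the code, not a quotation of Equation 8 of the paper) is, for each surface edge e grouped into dual vertex v_d,

L_{\mathrm{dev}}(e) \;=\; \Bigl|\, d_e - \tfrac{1}{|E_d|} \textstyle\sum_{e' \in E_d} d_{e'} \Bigr|, \qquad d_e = \lVert u_e - v_d \rVert,

where u_e is the zero-crossing on edge e and E_d is the set of surface edges assigned to v_d.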
-
- def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
- """
- Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
- """
- n_cubes = surf_cubes.shape[0]
-
- if beta_fx12 is not None:
- beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
- else:
- beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
-
- if alpha_fx8 is not None:
- alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
- else:
- alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
-
- if gamma_f is not None:
- gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
- else:
- gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
-
- return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
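For reference, the ranges these mappings produce (simple consequences of tanh in (-1, 1) and sigmoid in (0, 1); the numeric values assume the weight_scale s = 0.5 used by flexicubes_geometry.py below):

\beta,\ \alpha \in (1 - s,\ 1 + s) = (0.5,\ 1.5), \qquad \gamma \in \bigl(\tfrac{1-s}{2},\ \tfrac{1+s}{2}\bigr) = (0.25,\ 0.75).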
-
- @torch.no_grad()
- def _get_case_id(self, occ_fx8, surf_cubes, res):
- """
- Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
- ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
- supplementary material. It should be noted that this function assumes a regular grid.
- """
- case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
-
- problem_config = self.check_table.to(self.device)[case_ids]
- to_check = problem_config[..., 0] == 1
- problem_config = problem_config[to_check]
- if not isinstance(res, (list, tuple)):
- res = [res, res, res]
-
- # 'problem_config' only contains configurations for surface cubes. Next, we construct a 3D array,
- # 'problem_config_full', to store configurations for all cubes (with a default config for non-surface cubes).
- # This allows efficient checking of adjacent cubes.
- problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
- vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
- vol_idx_problem = vol_idx[surf_cubes][to_check]
- problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
- vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
-
- within_range = (
- vol_idx_problem_adj[..., 0] >= 0) & (
- vol_idx_problem_adj[..., 0] < res[0]) & (
- vol_idx_problem_adj[..., 1] >= 0) & (
- vol_idx_problem_adj[..., 1] < res[1]) & (
- vol_idx_problem_adj[..., 2] >= 0) & (
- vol_idx_problem_adj[..., 2] < res[2])
-
- vol_idx_problem = vol_idx_problem[within_range]
- vol_idx_problem_adj = vol_idx_problem_adj[within_range]
- problem_config = problem_config[within_range]
- problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
- vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
- # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
- to_invert = (problem_config_adj[..., 0] == 1)
- idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
- case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
- return case_ids
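For reference, the base case ID computed at the top of this function is just the 8-bit corner-occupancy code (read directly off the cube_corners_idx multiplication):

\mathrm{case\_id} \;=\; \sum_{i=0}^{7} \mathrm{occ}_i \, 2^{i} \;\in\; \{0, \dots, 255\}.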
-
- @torch.no_grad()
- def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
- """
- Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
- can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
- and marks the cube edges with this index.
- """
- occ_n = s_n < 0
- all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
- unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
-
- unique_edges = unique_edges.long()
- mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
-
- surf_edges_mask = mask_edges[_idx_map]
- counts = counts[_idx_map]
-
- mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
- mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
- # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
- # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
- idx_map = mapping[_idx_map]
- surf_edges = unique_edges[mask_edges]
- return surf_edges, idx_map, counts, surf_edges_mask
-
- @torch.no_grad()
- def _identify_surf_cubes(self, s_n, cube_fx8):
- """
- Identifies grid cubes that intersect with the underlying surface by checking if the signs at
- all corners are not identical.
- """
- occ_n = s_n < 0
- occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
- _occ_sum = torch.sum(occ_fx8, -1)
- surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
- return surf_cubes, occ_fx8
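Both predicates above reduce to occupancy counts: an edge crosses the surface iff exactly one of its two endpoints is occupied (occupancy sum equals 1), and a cube is a surface cube iff its corner-occupancy sum lies strictly between the extremes, i.e. 0 < \sum_{i=0}^{7} \mathrm{occ}_i < 8.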
-
- def _linear_interp(self, edges_weight, edges_x):
- """
- Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
- """
- edge_dim = edges_weight.dim() - 2
- assert edges_weight.shape[edge_dim] == 2
- edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim),
- -torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
- denominator = edges_weight.sum(edge_dim)
- ue = (edges_x * edges_weight).sum(edge_dim) / denominator
- return ue
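In closed form the interpolation above returns

u_e \;=\; \frac{s_1\, x_0 - s_0\, x_1}{s_1 - s_0},

i.e. the linear zero-crossing between the two endpoints. As a worked example (numbers ours, formula read off the code): s_0 = -1 at x_0 = 0 and s_1 = 3 at x_1 = 1 give u_e = 1/4, a quarter of the way along the edge.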
-
- def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
- p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
- norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
- c_bx3 = c_bx3.reshape(-1, 3)
- A = norm_bxnx3
- B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
-
- A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
- B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
- A = torch.cat([A, A_reg], 1)
- B = torch.cat([B, B_reg], 1)
- dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
- return dual_verts
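Equivalently, the least-squares solve above minimizes the regularized quadratic error below (our reading of the code; \lambda is qef_reg_scale, p_i and n_i are the masked crossing positions and normals, and c is the centroid passed in as c_bx3):

v_d \;=\; \arg\min_{v}\; \sum_i \bigl(n_i \cdot (v - p_i)\bigr)^2 \;+\; \lambda^2 \lVert v - c \rVert^2 .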
-
- def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
- """
- Computes the location of dual vertices as described in Section 4.2
- """
- alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
- surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
- surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
- zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
-
- idx_map = idx_map.reshape(-1, 12)
- num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
- edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
-
- total_num_vd = 0
- vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
- if grad_func is not None:
- normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
- vd = []
- for num in torch.unique(num_vd):
- cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
- curr_num_vd = cur_cubes.sum() * num
- curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
- curr_edge_group_to_vd = torch.arange(
- curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
- total_num_vd += curr_num_vd
- curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
- cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
-
- curr_mask = (curr_edge_group != -1)
- edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
- edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
- edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
- vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
- vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
-
- if grad_func is not None:
- with torch.no_grad():
- cube_e_verts_idx = idx_map[cur_cubes]
- curr_edge_group[~curr_mask] = 0
-
- verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
- verts_group_idx[verts_group_idx == -1] = 0
- verts_group_pos = torch.index_select(
- input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
- v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
- curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
- verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
-
- normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
- -1, num.item(), 7,
- 3)
- curr_mask = curr_mask.squeeze(2)
- vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
- verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
- edge_group = torch.cat(edge_group)
- edge_group_to_vd = torch.cat(edge_group_to_vd)
- edge_group_to_cube = torch.cat(edge_group_to_cube)
- vd_num_edges = torch.cat(vd_num_edges)
- vd_gamma = torch.cat(vd_gamma)
-
- if grad_func is not None:
- vd = torch.cat(vd)
- L_dev = torch.zeros([1], device=self.device)
- else:
- vd = torch.zeros((total_num_vd, 3), device=self.device)
- beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
-
- idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
-
- x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
- s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
-
- zero_crossing_group = torch.index_select(
- input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
-
- alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
- index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
- ue_group = self._linear_interp(s_group * alpha_group, x_group)
-
- beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
- index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
- beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
- vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
- L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
-
- v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
-
- vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
- 12 + edge_group, src=v_idx[edge_group_to_vd])
-
- return vd, L_dev, vd_gamma, vd_idx_map
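In the default branch (no grad_func), the dual vertex computed above is a beta-weighted average of the alpha-adjusted edge crossings (notation ours, reconstructed from the code):

\tilde{u}_e \;=\; \frac{\alpha_{e,1} s_{e,1}\, x_{e,0} - \alpha_{e,0} s_{e,0}\, x_{e,1}}{\alpha_{e,1} s_{e,1} - \alpha_{e,0} s_{e,0}}, \qquad v_d \;=\; \frac{\sum_{e \in E_d} \beta_e\, \tilde{u}_e}{\sum_{e \in E_d} \beta_e},

where E_d is the set of surface edges assigned to v_d.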
-
- def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
- """
- Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
- triangles based on the gamma parameter, as described in Section 4.3.
- """
- with torch.no_grad():
- group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
- group = idx_map.reshape(-1)[group_mask]
- vd_idx = vd_idx_map[group_mask]
- edge_indices, indices = torch.sort(group, stable=True)
- quad_vd_idx = vd_idx[indices].reshape(-1, 4)
-
- # Ensure all face directions point towards the positive SDF to maintain consistent winding.
- s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
- flip_mask = s_edges[:, 0] > 0
- quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
- quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
- if grad_func is not None:
- # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
- with torch.no_grad():
- vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
- quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
- gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
- gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
- else:
- quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
- gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
- 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
- gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
- 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
- if not training:
- mask = (gamma_02 > gamma_13).squeeze(1)
- faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
- faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
- faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
- faces = faces.reshape(-1, 3)
- else:
- vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
- vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
- torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
- vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
- torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
- weight_sum = (gamma_02 + gamma_13) + 1e-8
- vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
- weight_sum.unsqueeze(-1)).squeeze(1)
- vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
- vd = torch.cat([vd, vd_center])
- faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
- faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
- return vd, faces, s_edges, edge_indices
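During training the split above avoids a hard choice of diagonal. Reconstructed from the code (not quoted from the paper): with gamma_{02}, gamma_{13} as computed above and diagonal midpoints m_{02} = (v_0 + v_2)/2 and m_{13} = (v_1 + v_3)/2, a new vertex

c \;=\; \frac{\gamma_{02}\, m_{02} + \gamma_{13}\, m_{13}}{\gamma_{02} + \gamma_{13} + \varepsilon}, \qquad \varepsilon = 10^{-8},

is appended, and each quad contributes the four triangles (v_0, v_1, c), (v_1, v_3, c), (v_3, v_2, c), (v_2, v_0, c).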
-
- def _tetrahedralize(
- self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
- surf_cubes, training):
- """
- Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
- """
- occ_n = s_n < 0
- occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
- occ_sum = torch.sum(occ_fx8, -1)
-
- inside_verts = x_nx3[occ_n]
- mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
- mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
- """
- For each grid edge connecting two grid vertices with different
- signs, we first form a four-sided pyramid by connecting one
- of the grid vertices with four mesh vertices that correspond
- to the grid edge and then subdivide the pyramid into two tetrahedra
- """
- inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
- s_edges < 0]]
- if not training:
- inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
- else:
- inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
-
- tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
- """
- For each grid edge connecting two grid vertices with the
- same sign, the tetrahedron is formed by the two grid vertices
- and two vertices in consecutive adjacent cells
- """
- inside_cubes = (occ_sum == 8)
- inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
- inside_cubes_center_idx = torch.arange(
- inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
-
- surface_n_inside_cubes = surf_cubes | inside_cubes
- edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
- dtype=torch.long, device=x_nx3.device) * -1
- surf_cubes = surf_cubes[surface_n_inside_cubes]
- inside_cubes = inside_cubes[surface_n_inside_cubes]
- edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
- edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
-
- all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
- unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
- unique_edges = unique_edges.long()
- mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
- mask = mask_edges[_idx_map]
- counts = counts[_idx_map]
- mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
- mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
- idx_map = mapping[_idx_map]
-
- group_mask = (counts == 4) & mask
- group = idx_map.reshape(-1)[group_mask]
- edge_indices, indices = torch.sort(group)
- cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
- device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
- edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
- 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
- # Identify the face shared by the adjacent cells.
- cube_idx_4 = cube_idx[indices].reshape(-1, 4)
- edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
- shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
- cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
- # Identify an edge of the face with different signs and
- # select the mesh vertex corresponding to the identified edge.
- case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
- case_ids_expand[surf_cubes] = case_ids
- cases = case_ids_expand[cube_idx_4x2]
- quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
- mask = (quad_edge == -1).sum(-1) == 0
- inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
- tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
-
- tets = torch.cat([tets_surface, tets_inside])
- vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
- return vertices, tets
diff --git a/apps/third_party/CRM/util/flexicubes_geometry.py b/apps/third_party/CRM/util/flexicubes_geometry.py
deleted file mode 100644
index 5dec7635e7f275f6ef3867223ea2700d3af6f4ea..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/util/flexicubes_geometry.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
-
-import torch
-from util.flexicubes import FlexiCubes # replace later
-# from dmtet import sdf_reg_loss_batch
-import torch.nn.functional as F
-
-def get_center_boundary_index(grid_res, device):
- v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
- v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
- center_indices = torch.nonzero(v.reshape(-1))
-
- v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
- v[:2, ...] = True
- v[-2:, ...] = True
- v[:, :2, ...] = True
- v[:, -2:, ...] = True
- v[:, :, :2] = True
- v[:, :, -2:] = True
- boundary_indices = torch.nonzero(v.reshape(-1))
- return center_indices, boundary_indices
-
-###############################################################################
-# Geometry interface
-###############################################################################
-class FlexiCubesGeometry(object):
- def __init__(
- self, grid_res=64, scale=2.0, device='cuda', renderer=None,
- render_type='neural_render', args=None):
- super(FlexiCubesGeometry, self).__init__()
- self.grid_res = grid_res
- self.device = device
- self.args = args
- self.fc = FlexiCubes(device, weight_scale=0.5)
- self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
- if isinstance(scale, list):
- self.verts[:, 0] = self.verts[:, 0] * scale[0]
- self.verts[:, 1] = self.verts[:, 1] * scale[1]
- self.verts[:, 2] = self.verts[:, 2] * scale[1]
- else:
- self.verts = self.verts * scale
-
- all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
- self.all_edges = torch.unique(all_edges, dim=0)
-
- # Parameters used to fix the boundary SDF
- self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
- self.renderer = renderer
- self.render_type = render_type
-
- def getAABB(self):
- return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
-
- def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
- if indices is None:
- indices = self.indices
-
- verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
- beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
- gamma_f=weight_n[:, 20], training=is_training
- )
- return verts, faces, v_reg_loss
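A hedged sketch of the weight tensor get_mesh expects; the 21-channel layout is read directly off the slicing above, while the names geo, v_deformed_nx3 and sdf_n and the zero initialization (which yields uniform weights after the tanh/sigmoid mapping in FlexiCubes) are illustrative assumptions:

# geo = FlexiCubesGeometry(grid_res=64, device=device)   # hypothetical instance
n_cubes = geo.indices.shape[0]
# columns 0..11 -> beta (edge weights), 12..19 -> alpha (corner weights), 20 -> gamma (quad-split weight)
weight_n = torch.zeros(n_cubes, 21, device=geo.device, requires_grad=True)
verts, faces, reg = geo.get_mesh(v_deformed_nx3, sdf_n, weight_n=weight_n, is_training=True)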
-
-
- def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
- return_value = dict()
- if self.render_type == 'neural_render':
- tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
- mesh_v_nx3.unsqueeze(dim=0),
- mesh_f_fx3.int(),
- camera_mv_bx4x4,
- mesh_v_nx3.unsqueeze(dim=0),
- resolution=resolution,
- device=self.device,
- hierarchical_mask=hierarchical_mask
- )
-
- return_value['tex_pos'] = tex_pos
- return_value['mask'] = mask
- return_value['hard_mask'] = hard_mask
- return_value['rast'] = rast
- return_value['v_pos_clip'] = v_pos_clip
- return_value['mask_pyramid'] = mask_pyramid
- return_value['depth'] = depth
- else:
- raise NotImplementedError
-
- return return_value
-
- def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
- # Here we assume a batch of meshes (each may have a different mesh and geometry); for the other shapes, the batch size is 1
- v_list = []
- f_list = []
- n_batch = v_deformed_bxnx3.shape[0]
- all_render_output = []
- for i_batch in range(n_batch):
- verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
- v_list.append(verts_nx3)
- f_list.append(faces_fx3)
- render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
- all_render_output.append(render_output)
-
- # Gather the render outputs of all batch items, grouped by key
- return_keys = all_render_output[0].keys()
- return_value = dict()
- for k in return_keys:
- value = [v[k] for v in all_render_output]
- return_value[k] = value
- # We can do concatenation outside of the render
- return return_value
diff --git a/apps/third_party/CRM/util/renderer.py b/apps/third_party/CRM/util/renderer.py
deleted file mode 100644
index 2d1dc69dcaaf902cbb70af06834015ed1730dc0e..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/util/renderer.py
+++ /dev/null
@@ -1,49 +0,0 @@
-
-import torch
-import torch.nn as nn
-import nvdiffrast.torch as dr
-from util.flexicubes_geometry import FlexiCubesGeometry
-
-class Renderer(nn.Module):
- def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type):
- super().__init__()
-
- self.tet_grid_size = tet_grid_size
- self.camera_angle_num = camera_angle_num
- self.scale = scale
- self.geo_type = geo_type
- self.glctx = dr.RasterizeCudaContext()
-
- if self.geo_type == "flex":
- self.flexicubes = FlexiCubesGeometry(grid_res = self.tet_grid_size)
-
- def forward(self, data, sdf, deform, verts, tets, training=False, weight=None):
-
- results = {}
-
- deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95
- if self.geo_type == "flex":
- deform = deform * 0.5
-
- v_deformed = verts + deform
-
- verts_list = []
- faces_list = []
- reg_list = []
- n_shape = verts.shape[0]
- for i in range(n_shape):
- verts_i, faces_i, reg_i = self.flexicubes.get_mesh(v_deformed[i], sdf[i].squeeze(dim=-1),
- with_uv=False, indices=tets, weight_n=weight[i], is_training=training)
-
- verts_list.append(verts_i)
- faces_list.append(faces_i)
- reg_list.append(reg_i)
- verts = verts_list
- faces = faces_list
-
- flexicubes_surface_reg = torch.cat(reg_list).mean()
- flexicubes_weight_reg = (weight ** 2).mean()
- results["flex_surf_loss"] = flexicubes_surface_reg
- results["flex_weight_loss"] = flexicubes_weight_reg
-
- return results, verts, faces
\ No newline at end of file
diff --git a/apps/third_party/CRM/util/tables.py b/apps/third_party/CRM/util/tables.py
deleted file mode 100644
index 936a4bc5e2f95891f72651f2c42272e01a3a2bc3..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/util/tables.py
+++ /dev/null
@@ -1,791 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
-dmc_table = [
-[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
-[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
-[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
-[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
-]
-num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
-2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
-1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
-1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
-2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
-3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
-2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
-1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
-1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
-1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
-check_table = [
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 1, 0, 0, 194],
-[1, -1, 0, 0, 193],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 1, 0, 164],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, -1, 0, 161],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 0, 1, 152],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 0, 1, 145],
-[1, 0, 0, 1, 144],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 0, -1, 137],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 1, 0, 133],
-[1, 0, 1, 0, 132],
-[1, 1, 0, 0, 131],
-[1, 1, 0, 0, 130],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 0, 1, 100],
-[0, 0, 0, 0, 0],
-[1, 0, 0, 1, 98],
-[0, 0, 0, 0, 0],
-[1, 0, 0, 1, 96],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 1, 0, 88],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, -1, 0, 82],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 1, 0, 74],
-[0, 0, 0, 0, 0],
-[1, 0, 1, 0, 72],
-[0, 0, 0, 0, 0],
-[1, 0, 0, -1, 70],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, -1, 0, 0, 67],
-[0, 0, 0, 0, 0],
-[1, -1, 0, 0, 65],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 1, 0, 0, 56],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, -1, 0, 0, 52],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 1, 0, 0, 44],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 1, 0, 0, 40],
-[0, 0, 0, 0, 0],
-[1, 0, 0, -1, 38],
-[1, 0, -1, 0, 37],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, -1, 0, 33],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, -1, 0, 0, 28],
-[0, 0, 0, 0, 0],
-[1, 0, -1, 0, 26],
-[1, 0, 0, -1, 25],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, -1, 0, 0, 20],
-[0, 0, 0, 0, 0],
-[1, 0, -1, 0, 18],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 0, -1, 9],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[1, 0, 0, -1, 6],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0]
-]
-tet_table = [
-[-1, -1, -1, -1, -1, -1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[4, 4, 4, 4, 4, 4],
-[0, 0, 0, 0, 0, 0],
-[4, 0, 0, 4, 4, -1],
-[1, 1, 1, 1, 1, 1],
-[4, 4, 4, 4, 4, 4],
-[0, 4, 0, 4, 4, -1],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[5, 5, 5, 5, 5, 5],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 0, 2, -1, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[2, -1, 2, 4, 4, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 0, 2, 4, 4, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 4, 2, 4, 4, 2],
-[0, 4, 0, 4, 4, 0],
-[2, 0, 2, 0, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 5, 2, 5, 5, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 0, 2, 0, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[1, 1, 1, 1, 1, 1],
-[0, 1, 1, -1, 0, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[4, 1, 1, 4, 4, 1],
-[0, 1, 1, 0, 0, 1],
-[4, 0, 0, 4, 4, 0],
-[2, 2, 2, 2, 2, 2],
-[-1, 1, 1, 4, 4, 1],
-[0, 1, 1, 4, 4, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[5, 1, 1, 5, 5, 1],
-[0, 1, 1, 0, 0, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[8, 8, 8, 8, 8, 8],
-[1, 1, 1, 4, 4, 1],
-[0, 0, 0, 0, 0, 0],
-[4, 0, 0, 4, 4, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 4, 4, 1],
-[0, 4, 0, 4, 4, 0],
-[0, 0, 0, 0, 0, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 5, 5, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[5, 5, 5, 5, 5, 5],
-[6, 6, 6, 6, 6, 6],
-[6, -1, 0, 6, 0, 6],
-[6, 0, 0, 6, 0, 6],
-[6, 1, 1, 6, 1, 6],
-[4, 4, 4, 4, 4, 4],
-[0, 0, 0, 0, 0, 0],
-[4, 0, 0, 4, 4, 4],
-[1, 1, 1, 1, 1, 1],
-[6, 4, -1, 6, 4, 6],
-[6, 4, 0, 6, 4, 6],
-[6, 0, 0, 6, 0, 6],
-[6, 1, 1, 6, 1, 6],
-[5, 5, 5, 5, 5, 5],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 0, 2, 2, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 0, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 4, 2, 2, 4, 2],
-[0, 4, 0, 4, 4, 0],
-[2, 0, 2, 2, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[6, 1, 1, 6, -1, 6],
-[6, 1, 1, 6, 0, 6],
-[6, 0, 0, 6, 0, 6],
-[6, 2, 2, 6, 2, 6],
-[4, 1, 1, 4, 4, 1],
-[0, 1, 1, 0, 0, 1],
-[4, 0, 0, 4, 4, 4],
-[2, 2, 2, 2, 2, 2],
-[6, 1, 1, 6, 4, 6],
-[6, 1, 1, 6, 4, 6],
-[6, 0, 0, 6, 0, 6],
-[6, 2, 2, 6, 2, 6],
-[5, 1, 1, 5, 5, 1],
-[0, 1, 1, 0, 0, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[6, 6, 6, 6, 6, 6],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 1, 4, 1],
-[0, 4, 0, 4, 4, 0],
-[0, 0, 0, 0, 0, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 5, 0, 5, 0, 5],
-[5, 5, 5, 5, 5, 5],
-[5, 5, 5, 5, 5, 5],
-[0, 5, 0, 5, 0, 5],
-[-1, 5, 0, 5, 0, 5],
-[1, 5, 1, 5, 1, 5],
-[4, 5, -1, 5, 4, 5],
-[0, 5, 0, 5, 0, 5],
-[4, 5, 0, 5, 4, 5],
-[1, 5, 1, 5, 1, 5],
-[4, 4, 4, 4, 4, 4],
-[0, 4, 0, 4, 4, 4],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[6, 6, 6, 6, 6, 6],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[2, 5, 2, 5, -1, 5],
-[0, 5, 0, 5, 0, 5],
-[2, 5, 2, 5, 0, 5],
-[1, 5, 1, 5, 1, 5],
-[2, 5, 2, 5, 4, 5],
-[0, 5, 0, 5, 0, 5],
-[2, 5, 2, 5, 4, 5],
-[1, 5, 1, 5, 1, 5],
-[2, 4, 2, 4, 4, 2],
-[0, 4, 0, 4, 4, 4],
-[2, 0, 2, 0, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 6, 2, 6, 6, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 0, 2, 0, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[1, 1, 1, 1, 1, 1],
-[0, 1, 1, 1, 0, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[4, 1, 1, 1, 4, 1],
-[0, 1, 1, 1, 0, 1],
-[4, 0, 0, 4, 4, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[5, 5, 5, 5, 5, 5],
-[1, 1, 1, 1, 4, 1],
-[0, 0, 0, 0, 0, 0],
-[4, 0, 0, 4, 4, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 1, 1, 1],
-[6, 0, 0, 6, 0, 6],
-[0, 0, 0, 0, 0, 0],
-[6, 6, 6, 6, 6, 6],
-[5, 5, 5, 5, 5, 5],
-[5, 5, 0, 5, 0, 5],
-[5, 5, 0, 5, 0, 5],
-[5, 5, 1, 5, 1, 5],
-[4, 4, 4, 4, 4, 4],
-[0, 0, 0, 0, 0, 0],
-[4, 4, 0, 4, 4, 4],
-[1, 1, 1, 1, 1, 1],
-[4, 4, 4, 4, 4, 4],
-[4, 4, 0, 4, 4, 4],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[8, 8, 8, 8, 8, 8],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 0, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[4, 1, 1, 4, 4, 1],
-[2, 2, 2, 2, 2, 2],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[1, 1, 1, 1, 1, 1],
-[1, 1, 1, 1, 1, 1],
-[1, 1, 1, 1, 0, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[2, 4, 2, 4, 4, 2],
-[1, 1, 1, 1, 1, 1],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[2, 2, 2, 2, 2, 2],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[5, 5, 5, 5, 5, 5],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[4, 4, 4, 4, 4, 4],
-[1, 1, 1, 1, 1, 1],
-[0, 0, 0, 0, 0, 0],
-[0, 0, 0, 0, 0, 0],
-[12, 12, 12, 12, 12, 12]
-]
diff --git a/apps/third_party/CRM/util/utils.py b/apps/third_party/CRM/util/utils.py
deleted file mode 100644
index ab3cad9b703b1cdd9a3d6a66e82f898185ebf6c8..0000000000000000000000000000000000000000
--- a/apps/third_party/CRM/util/utils.py
+++ /dev/null
@@ -1,194 +0,0 @@
-import numpy as np
-import torch
-import random
-
-
-# Reworked so this matches gluPerspective / glm::perspective, but parameterized by the horizontal FOV (fovx)
-def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
- # y = np.tan(fovy / 2)
- x = np.tan(fovx / 2)
- return torch.tensor([[1/x, 0, 0, 0],
- [ 0, -aspect/x, 0, 0],
- [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
- [ 0, 0, -1, 0]], dtype=torch.float32, device=device)
-
-
-def translate(x, y, z, device=None):
- return torch.tensor([[1, 0, 0, x],
- [0, 1, 0, y],
- [0, 0, 1, z],
- [0, 0, 0, 1]], dtype=torch.float32, device=device)
-
-
-def rotate_x(a, device=None):
- s, c = np.sin(a), np.cos(a)
- return torch.tensor([[1, 0, 0, 0],
- [0, c, -s, 0],
- [0, s, c, 0],
- [0, 0, 0, 1]], dtype=torch.float32, device=device)
-
-
-def rotate_y(a, device=None):
- s, c = np.sin(a), np.cos(a)
- return torch.tensor([[ c, 0, s, 0],
- [ 0, 1, 0, 0],
- [-s, 0, c, 0],
- [ 0, 0, 0, 1]], dtype=torch.float32, device=device)
-
-
-def rotate_z(a, device=None):
- s, c = np.sin(a), np.cos(a)
- return torch.tensor([[c, -s, 0, 0],
- [s, c, 0, 0],
- [0, 0, 1, 0],
- [0, 0, 0, 1]], dtype=torch.float32, device=device)
-
-@torch.no_grad()
-def batch_random_rotation_translation(b, t, device=None):
- m = np.random.normal(size=[b, 3, 3])
- m[:, 1] = np.cross(m[:, 0], m[:, 2])
- m[:, 2] = np.cross(m[:, 0], m[:, 1])
- m = m / np.linalg.norm(m, axis=2, keepdims=True)
- m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant')
- m[:, 3, 3] = 1.0
- m[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3])
- return torch.tensor(m, dtype=torch.float32, device=device)
-
-@torch.no_grad()
-def random_rotation_translation(t, device=None):
- m = np.random.normal(size=[3, 3])
- m[1] = np.cross(m[0], m[2])
- m[2] = np.cross(m[0], m[1])
- m = m / np.linalg.norm(m, axis=1, keepdims=True)
- m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
- m[3, 3] = 1.0
- m[:3, 3] = np.random.uniform(-t, t, size=[3])
- return torch.tensor(m, dtype=torch.float32, device=device)
-
-
-@torch.no_grad()
-def random_rotation(device=None):
- m = np.random.normal(size=[3, 3])
- m[1] = np.cross(m[0], m[2])
- m[2] = np.cross(m[0], m[1])
- m = m / np.linalg.norm(m, axis=1, keepdims=True)
- m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
- m[3, 3] = 1.0
- m[:3, 3] = np.array([0,0,0]).astype(np.float32)
- return torch.tensor(m, dtype=torch.float32, device=device)
-
-
-def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
- return torch.sum(x*y, -1, keepdim=True)
-
-
-def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
- return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
-
-
-def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
- return x / length(x, eps)
-
-
-def lr_schedule(iter, warmup_iter, scheduler_decay):
- if iter < warmup_iter:
- return iter / warmup_iter
- return max(0.0, 10 ** (
- -(iter - warmup_iter) * scheduler_decay))
-
-
-def trans_depth(depth):
- depth = depth[0].detach().cpu().numpy()
- valid = depth > 0
- depth[valid] -= depth[valid].min()
- depth[valid] = ((depth[valid] / depth[valid].max()) * 255)
- return depth.astype('uint8')
-
-
-def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
- assert isinstance(input, torch.Tensor)
- if posinf is None:
- posinf = torch.finfo(input.dtype).max
- if neginf is None:
- neginf = torch.finfo(input.dtype).min
- assert nan == 0
- return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
-
-
-def load_item(filepath):
- with open(filepath, 'r') as f:
- items = [name.strip() for name in f.readlines()]
- return set(items)
-
-def load_prompt(filepath):
- uuid2prompt = {}
- with open(filepath, 'r') as f:
- for line in f.readlines():
- list_line = line.split(',')
- uuid2prompt[list_line[0]] = ','.join(list_line[1:]).strip()
- return uuid2prompt
-
-def resize_and_center_image(image_tensor, scale=0.95, c = 0, shift = 0, rgb=False, aug_shift = 0):
- if scale == 1:
- return image_tensor
- B, C, H, W = image_tensor.shape
- new_H, new_W = int(H * scale), int(W * scale)
- resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False).squeeze(0)
- background = torch.zeros_like(image_tensor) + c
- start_y, start_x = (H - new_H) // 2, (W - new_W) // 2
- if shift == 0:
- background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image
- else:
- for i in range(B):
- randx = random.randint(-shift, shift)
- randy = random.randint(-shift, shift)
- if rgb == True:
- if i == 0 or i==2 or i==4:
- randx = 0
- randy = 0
- background[i, :, start_y+randy:start_y + new_H+randy, start_x+randx:start_x + new_W+randx] = resized_image[i]
- if aug_shift == 0:
- return background
- for i in range(B):
- for j in range(C):
- background[i, j, :, :] += (random.random() - 0.5)*2 * aug_shift / 255
- return background
-
-def get_tri(triview_color, dim = 1, blender=True, c = 0, scale=0.95, shift = 0, fix = False, rgb=False, aug_shift = 0):
- # triview_color: [6,C,H,W]
- # rgb is useful when shift is not 0
- triview_color = resize_and_center_image(triview_color, scale=scale, c = c, shift=shift,rgb=rgb, aug_shift = aug_shift)
- if blender is False:
- triview_color0 = torch.rot90(triview_color[0],k=2,dims=[1,2])
- triview_color1 = torch.rot90(triview_color[4],k=1,dims=[1,2]).flip(2).flip(1)
- triview_color2 = torch.rot90(triview_color[5],k=1,dims=[1,2]).flip(2)
- triview_color3 = torch.rot90(triview_color[3],k=2,dims=[1,2]).flip(2)
- triview_color4 = torch.rot90(triview_color[1],k=3,dims=[1,2]).flip(1)
- triview_color5 = torch.rot90(triview_color[2],k=3,dims=[1,2]).flip(1).flip(2)
- else:
- triview_color0 = torch.rot90(triview_color[2],k=2,dims=[1,2])
- triview_color1 = torch.rot90(triview_color[4],k=0,dims=[1,2]).flip(2).flip(1)
- triview_color2 = torch.rot90(torch.rot90(triview_color[0],k=3,dims=[1,2]).flip(2), k=2,dims=[1,2])
- triview_color3 = torch.rot90(torch.rot90(triview_color[5],k=2,dims=[1,2]).flip(2), k=2,dims=[1,2])
- triview_color4 = torch.rot90(triview_color[1],k=2,dims=[1,2]).flip(1).flip(1).flip(2)
- triview_color5 = torch.rot90(triview_color[3],k=1,dims=[1,2]).flip(1).flip(2)
- if fix == True:
- triview_color0[1] = triview_color0[1] * 0
- triview_color0[2] = triview_color0[2] * 0
- triview_color3[1] = triview_color3[1] * 0
- triview_color3[2] = triview_color3[2] * 0
-
- triview_color1[0] = triview_color1[0] * 0
- triview_color1[1] = triview_color1[1] * 0
- triview_color4[0] = triview_color4[0] * 0
- triview_color4[1] = triview_color4[1] * 0
-
- triview_color2[0] = triview_color2[0] * 0
- triview_color2[2] = triview_color2[2] * 0
- triview_color5[0] = triview_color5[0] * 0
- triview_color5[2] = triview_color5[2] * 0
- color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2)
- color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2)
- color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim = dim)
- return color_tensor_gt
-
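
For reference, `get_tri` above takes the six generated orthographic views as a `[6, C, H, W]` tensor, optionally shrinks and re-centers each view onto a constant background via `resize_and_center_image`, rotates/flips the views into a consistent orientation, and tiles them into two rows of three. A minimal usage sketch follows, assuming the helpers defined above are importable; the 256x256 resolution and the `blender=True` layout are illustrative choices only.

import torch

# Illustrative input: six square RGB views, channel-first, values in [0, 1].
triview_color = torch.rand(6, 3, 256, 256)

# Each three-view strip is concatenated along width; dim=1 then stacks the two
# strips along height, giving a 2x3 grid.
grid = get_tri(triview_color, dim=1, blender=True, scale=0.95)
print(grid.shape)  # torch.Size([3, 512, 768]) for 256x256 inputs
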
diff --git a/apps/third_party/Era3D/configs/test_unclip-512-6view.yaml b/apps/third_party/Era3D/configs/test_unclip-512-6view.yaml
deleted file mode 100644
index d9353a5984d60e96390da5a3d06a70f09ae9d6b0..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/configs/test_unclip-512-6view.yaml
+++ /dev/null
@@ -1,56 +0,0 @@
-pretrained_model_name_or_path: './MacLab-Era3D-512-6view'
-revision: null
-
-num_views: 6
-validation_dataset:
- prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_6view
- root_dir: 'examples'
- num_views: ${num_views}
- bg_color: 'white'
- img_wh: [512, 512]
- num_validation_samples: 1000
- crop_size: 420
-
-pred_type: 'joint'
-save_dir: 'mv_res'
-save_mode: 'rgba' # 'concat', 'rgba', 'rgb'
-seed: 42
-validation_batch_size: 1
-dataloader_num_workers: 1
-local_rank: -1
-
-pipe_kwargs:
- num_views: ${num_views}
-
-validation_guidance_scales: [3.0]
-pipe_validation_kwargs:
- num_inference_steps: 40
- eta: 1.0
-
-validation_grid_nrow: ${num_views}
-regress_elevation: true
-regress_focal_length: true
-unet_from_pretrained_kwargs:
- unclip: true
- sdxl: false
- num_views: ${num_views}
- sample_size: 64
- zero_init_conv_in: false # modify
-
- regress_elevation: ${regress_elevation}
- regress_focal_length: ${regress_focal_length}
- camera_embedding_type: e_de_da_sincos
- projection_camera_embeddings_input_dim: 4 # 2 for elevation and 6 for focal_length
- zero_init_camera_projection: false
- num_regress_blocks: 3
-
- cd_attention_last: false
- cd_attention_mid: false
- multiview_attention: true
- sparse_mv_attention: true
- selfattn_block: self_rowwise
- mvcd_attention: true
-
- use_dino: false
-
-enable_xformers_memory_efficient_attention: true
\ No newline at end of file
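
The `${num_views}` entries in this file are OmegaConf-style interpolations that resolve against the top-level `num_views` key once the YAML is loaded. A minimal loading sketch, assuming `omegaconf` is the loader (an assumption based on the interpolation syntax; the path is a placeholder):

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/test_unclip-512-6view.yaml")
print(cfg.num_views)                                # 6
print(cfg.validation_dataset.num_views)             # 6, resolved from the top-level key
print(cfg.unet_from_pretrained_kwargs.sample_size)  # 64
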
diff --git a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/clr_embeds.pt b/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/clr_embeds.pt
deleted file mode 100644
index de105d6a0a017da97af4644608a87785ec54d9cb..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/clr_embeds.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b9e51666588d0f075e031262744d371e12076160231aab19a531dbf7ab976e4d
-size 946932
diff --git a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/normal_embeds.pt b/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/normal_embeds.pt
deleted file mode 100644
index 7fb88dcf24443b235588cf426eba3951316e825f..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/data/fixed_prompt_embeds_6view/normal_embeds.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:53dfcd17f62fbfd8aeba60b1b05fa7559d72179738fd048e2ac1d53e5be5ed9d
-size 946941
diff --git a/apps/third_party/Era3D/data/normal_utils.py b/apps/third_party/Era3D/data/normal_utils.py
deleted file mode 100644
index 1a390082ca0d891f044eca3c6fc291d8f29036de..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/data/normal_utils.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import numpy as np
-def deg2rad(deg):
- return deg*np.pi/180
-
-def inv_RT(RT):
- # RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0)
- RT_inv = np.linalg.inv(RT)
-
- return RT_inv[:3, :]
-def camNormal2worldNormal(rot_c2w, camNormal):
- H,W,_ = camNormal.shape
- normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
-
- return normal_img
-
-def worldNormal2camNormal(rot_w2c, normal_map_world):
- H,W,_ = normal_map_world.shape
- # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
-
- # faster version
- # Reshape the normal map into a 2D array where each row represents a normal vector
- normal_map_flat = normal_map_world.reshape(-1, 3)
-
- # Transform the normal vectors using the transformation matrix
- normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T)
-
- # Reshape the transformed normal map back to its original shape
- normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape)
-
- return normal_map_camera
-
-def trans_normal(normal, RT_w2c, RT_w2c_target):
-
- # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
- # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
-
- relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
- return worldNormal2camNormal(relative_RT[:3,:3], normal)
-
-def trans_normal_complex(normal, RT_w2c, RT_w2c_rela_to_cond):
- # camview -> world -> condview
- normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
- # debug_normal_world = normal2img(normal_world)
-
- # relative_RT = np.matmul(RT_w2c_rela_to_cond[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
- normal_target_cam = worldNormal2camNormal(RT_w2c_rela_to_cond[:3,:3], normal_world)
- # normal_condview = normal2img(normal_target_cam)
- return normal_target_cam
-def img2normal(img):
- return (img/255.)*2-1
-
-def normal2img(normal):
- return np.uint8((normal*0.5+0.5)*255)
-
-def norm_normalize(normal, dim=-1):
-
- normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
-
- return normal
-
-def plot_grid_images(images, row, col, path=None):
- import cv2
- """
- Args:
- images: np.array [B, H, W, 3]
- row:
- col:
- save_path:
-
- Returns:
-
- """
- images = images.detach().cpu().numpy()
- assert row * col == images.shape[0]
- images = np.vstack([np.hstack(images[r * col:(r + 1) * col]) for r in range(row)])
- if path:
- cv2.imwrite(path, images[:,:,::-1] * 255)
- return images
\ No newline at end of file
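
`worldNormal2camNormal` applies a world-to-camera rotation to every pixel of an HxWx3 normal map, and `camNormal2worldNormal` undoes it when given the inverse (transposed) rotation. A small round-trip check, assuming the helpers above are importable; the random rotation and the 4x4 map size are arbitrary:

import numpy as np

rng = np.random.default_rng(0)
R, _ = np.linalg.qr(rng.normal(size=(3, 3)))            # orthonormal world-to-camera rotation
normals_world = norm_normalize(rng.normal(size=(4, 4, 3)))

normals_cam = worldNormal2camNormal(R, normals_world)   # rotate each per-pixel normal
normals_back = camNormal2worldNormal(R.T, normals_cam)  # R.T is the camera-to-world rotation
assert np.allclose(normals_back, normals_world)
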
diff --git a/apps/third_party/Era3D/data/single_image_dataset.py b/apps/third_party/Era3D/data/single_image_dataset.py
deleted file mode 100644
index e225ab8a7765173f576942f741bf2008567136d6..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/data/single_image_dataset.py
+++ /dev/null
@@ -1,247 +0,0 @@
-from typing import Dict
-import numpy as np
-from omegaconf import DictConfig, ListConfig
-import torch
-from torch.utils.data import Dataset
-from pathlib import Path
-import json
-from PIL import Image
-from torchvision import transforms
-from einops import rearrange
-from typing import Literal, Tuple, Optional, Any
-import cv2
-import random
-
-import json
-import os, sys
-import math
-
-from glob import glob
-
-import PIL.Image
-from .normal_utils import trans_normal, normal2img, img2normal
-import pdb
-
-import cv2
-import numpy as np
-
-def add_margin(pil_img, color=0, size=256):
- width, height = pil_img.size
- result = Image.new(pil_img.mode, (size, size), color)
- result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
- return result
-
-def scale_and_place_object(image, scale_factor):
- assert np.shape(image)[-1]==4 # RGBA
-
- # Extract the alpha channel (transparency) and the object (RGB channels)
- alpha_channel = image[:, :, 3]
-
- # Find the bounding box coordinates of the object
- coords = cv2.findNonZero(alpha_channel)
- x, y, width, height = cv2.boundingRect(coords)
-
- # Calculate the scale factor for resizing
- original_height, original_width = image.shape[:2]
-
- if width > height:
- size = width
- original_size = original_width
- else:
- size = height
- original_size = original_height
-
- scale_factor = min(scale_factor, size / (original_size+0.0))
-
- new_size = scale_factor * original_size
- scale_factor = new_size / size
-
- # Calculate the new size based on the scale factor
- new_width = int(width * scale_factor)
- new_height = int(height * scale_factor)
-
- center_x = original_width // 2
- center_y = original_height // 2
-
- paste_x = center_x - (new_width // 2)
- paste_y = center_y - (new_height // 2)
-
- # Resize the object (RGB channels) to the new size
- rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height))
-
- # Create a new RGBA image with the resized image
- new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
-
- new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
-
- return new_image
-
-class SingleImageDataset(Dataset):
- def __init__(self,
- root_dir: str,
- num_views: int,
- img_wh: Tuple[int, int],
- bg_color: str,
- crop_size: int = 224,
- single_image: Optional[PIL.Image.Image] = None,
- num_validation_samples: Optional[int] = None,
- filepaths: Optional[list] = None,
- cond_type: Optional[str] = None,
- prompt_embeds_path: Optional[str] = None,
- gt_path: Optional[str] = None
- ) -> None:
- """Create a dataset from a folder of images.
- If you pass in a root directory it will be searched for images
-        ending in .png, .jpg, or .webp
- """
- self.root_dir = root_dir
- self.num_views = num_views
- self.img_wh = img_wh
- self.crop_size = crop_size
- self.bg_color = bg_color
- self.cond_type = cond_type
- self.gt_path = gt_path
-
-
- if single_image is None:
- if filepaths is None:
- # Get a list of all files in the directory
- file_list = os.listdir(self.root_dir)
- else:
- file_list = filepaths
-
- # Filter the files that end with .png or .jpg
- self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg', '.webp'))]
- else:
- self.file_list = None
-
- # load all images
- self.all_images = []
- self.all_alphas = []
- bg_color = self.get_bg_color()
-
- if single_image is not None:
- image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image)
- self.all_images.append(image)
- self.all_alphas.append(alpha)
- else:
- for file in self.file_list:
- print(os.path.join(self.root_dir, file))
- image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
- self.all_images.append(image)
- self.all_alphas.append(alpha)
-
-
-
- self.all_images = self.all_images[:num_validation_samples]
- self.all_alphas = self.all_alphas[:num_validation_samples]
-
- try:
- self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
- self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') # 4view
- except:
- self.color_text_embeds = torch.load(f'{prompt_embeds_path}/embeds.pt')
- self.normal_text_embeds = None
-
- def __len__(self):
- return len(self.all_images)
-
- def get_bg_color(self):
- if self.bg_color == 'white':
- bg_color = np.array([1., 1., 1.], dtype=np.float32)
- elif self.bg_color == 'black':
- bg_color = np.array([0., 0., 0.], dtype=np.float32)
- elif self.bg_color == 'gray':
- bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
- elif self.bg_color == 'random':
- bg_color = np.random.rand(3)
- elif isinstance(self.bg_color, float):
- bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
- else:
- raise NotImplementedError
- return bg_color
-
-
- def load_image(self, img_path, bg_color, return_type='np', Imagefile=None):
- # pil always returns uint8
- if Imagefile is None:
- image_input = Image.open(img_path)
- else:
- image_input = Imagefile
- image_size = self.img_wh[0]
-
- if self.crop_size!=-1:
- alpha_np = np.asarray(image_input)[:, :, 3]
- coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
- min_x, min_y = np.min(coords, 0)
- max_x, max_y = np.max(coords, 0)
- ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
- h, w = ref_img_.height, ref_img_.width
- scale = self.crop_size / max(h, w)
- h_, w_ = int(scale * h), int(scale * w)
- ref_img_ = ref_img_.resize((w_, h_))
- image_input = add_margin(ref_img_, size=image_size)
- else:
- image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
- image_input = image_input.resize((image_size, image_size))
-
- # img = scale_and_place_object(img, self.scale_ratio)
- img = np.array(image_input)
- img = img.astype(np.float32) / 255. # [0, 1]
- assert img.shape[-1] == 4 # RGBA
-
- alpha = img[...,3:4]
- img = img[...,:3] * alpha + bg_color * (1 - alpha)
-
- if return_type == "np":
- pass
- elif return_type == "pt":
- img = torch.from_numpy(img)
- alpha = torch.from_numpy(alpha)
- else:
- raise NotImplementedError
-
- return img, alpha
-
-
- def __getitem__(self, index):
- image = self.all_images[index%len(self.all_images)]
- alpha = self.all_alphas[index%len(self.all_images)]
- if self.file_list is not None:
- filename = self.file_list[index%len(self.all_images)].replace(".png", "")
- else:
- filename = 'null'
- img_tensors_in = [
- image.permute(2, 0, 1)
- ] * self.num_views
-
- alpha_tensors_in = [
- alpha.permute(2, 0, 1)
- ] * self.num_views
-
- img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
-        alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 1, H, W)
-
- if self.gt_path is not None:
- gt_image = self.gt_images[index%len(self.all_images)]
- gt_alpha = self.gt_alpha[index%len(self.all_images)]
- gt_img_tensors_in = [gt_image.permute(2, 0, 1) ] * self.num_views
- gt_alpha_tensors_in = [gt_alpha.permute(2, 0, 1) ] * self.num_views
- gt_img_tensors_in = torch.stack(gt_img_tensors_in, dim=0).float()
- gt_alpha_tensors_in = torch.stack(gt_alpha_tensors_in, dim=0).float()
-
- normal_prompt_embeddings = self.normal_text_embeds if hasattr(self, 'normal_text_embeds') else None
- color_prompt_embeddings = self.color_text_embeds if hasattr(self, 'color_text_embeds') else None
-
- out = {
- 'imgs_in': img_tensors_in.unsqueeze(0),
- 'alphas': alpha_tensors_in.unsqueeze(0),
- 'normal_prompt_embeddings': normal_prompt_embeddings.unsqueeze(0),
- 'color_prompt_embeddings': color_prompt_embeddings.unsqueeze(0),
- 'filename': filename,
- }
-
- return out
-
-
-
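
`SingleImageDataset` can wrap either a folder of RGBA images or a single in-memory `PIL.Image`; each item simply repeats the conditioning image `num_views` times alongside the fixed per-view prompt embeddings. A rough construction sketch mirroring the YAML config earlier in this patch; the image path is a placeholder, and `prompt_embeds_path` is assumed to contain the `clr_embeds.pt` / `normal_embeds.pt` files:

from PIL import Image

img = Image.open("examples/input.png").convert("RGBA")   # placeholder path
dataset = SingleImageDataset(
    root_dir=None,
    num_views=6,
    img_wh=(512, 512),
    bg_color="white",
    crop_size=420,
    single_image=img,
    prompt_embeds_path="mvdiffusion/data/fixed_prompt_embeds_6view",
)
batch = dataset[0]
print(batch["imgs_in"].shape)   # torch.Size([1, 6, 3, 512, 512])
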
diff --git a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_image.py b/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_image.py
deleted file mode 100644
index 68d8b4d0ba8d05bfd6a760acf99080cfde7cbb62..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_image.py
+++ /dev/null
@@ -1,1029 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.embeddings import ImagePositionalEmbeddings
-from diffusers.utils import BaseOutput, deprecate
-from diffusers.utils.torch_utils import maybe_allow_in_graph
-from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
-from diffusers.models.embeddings import PatchEmbed
-from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.import_utils import is_xformers_available
-
-from einops import rearrange, repeat
-import pdb
-import random
-
-
-if is_xformers_available():
- import xformers
- import xformers.ops
-else:
- xformers = None
-
-def my_repeat(tensor, num_repeats):
- """
- Repeat a tensor along a given dimension
- """
- if len(tensor.shape) == 3:
- return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
- elif len(tensor.shape) == 4:
- return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
-
-
-@dataclass
-class TransformerMV2DModelOutput(BaseOutput):
- """
- The output of [`Transformer2DModel`].
-
- Args:
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
- distributions for the unnoised latent pixels.
- """
-
- sample: torch.FloatTensor
-
-
-class TransformerMV2DModel(ModelMixin, ConfigMixin):
- """
- A 2D Transformer model for image-like data.
-
- Parameters:
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
- in_channels (`int`, *optional*):
- The number of channels in the input and output (specify if the input is **continuous**).
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
- This is fixed during training since it is used to learn a number of position embeddings.
- num_vector_embeds (`int`, *optional*):
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
- Includes the class for the masked latent pixel.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
- num_embeds_ada_norm ( `int`, *optional*):
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
- added to the hidden states.
-
-            During inference, you can denoise for up to, but not more than, `num_embeds_ada_norm` steps.
- attention_bias (`bool`, *optional*):
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
- """
-
- @register_to_config
- def __init__(
- self,
- num_attention_heads: int = 16,
- attention_head_dim: int = 88,
- in_channels: Optional[int] = None,
- out_channels: Optional[int] = None,
- num_layers: int = 1,
- dropout: float = 0.0,
- norm_num_groups: int = 32,
- cross_attention_dim: Optional[int] = None,
- attention_bias: bool = False,
- sample_size: Optional[int] = None,
- num_vector_embeds: Optional[int] = None,
- patch_size: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- norm_type: str = "layer_norm",
- norm_elementwise_affine: bool = True,
- num_views: int = 1,
- cd_attention_last: bool=False,
- cd_attention_mid: bool=False,
- multiview_attention: bool=True,
- sparse_mv_attention: bool = False,
- mvcd_attention: bool=False
- ):
- super().__init__()
- self.use_linear_projection = use_linear_projection
- self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
- inner_dim = num_attention_heads * attention_head_dim
-
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
- # Define whether input is continuous or discrete depending on configuration
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
- self.is_input_vectorized = num_vector_embeds is not None
- self.is_input_patches = in_channels is not None and patch_size is not None
-
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
- deprecation_message = (
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
-                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
-                " Please make sure to update the config accordingly, as leaving `norm_type` unchanged might lead to incorrect"
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
- )
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
- norm_type = "ada_norm"
-
- if self.is_input_continuous and self.is_input_vectorized:
- raise ValueError(
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
- " sure that either `in_channels` or `num_vector_embeds` is None."
- )
- elif self.is_input_vectorized and self.is_input_patches:
- raise ValueError(
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
- " sure that either `num_vector_embeds` or `num_patches` is None."
- )
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
- raise ValueError(
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
- )
-
- # 2. Define input layers
- if self.is_input_continuous:
- self.in_channels = in_channels
-
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
- if use_linear_projection:
- self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
- else:
- self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
- elif self.is_input_vectorized:
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
-
- self.height = sample_size
- self.width = sample_size
- self.num_vector_embeds = num_vector_embeds
- self.num_latent_pixels = self.height * self.width
-
- self.latent_image_embedding = ImagePositionalEmbeddings(
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
- )
- elif self.is_input_patches:
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
-
- self.height = sample_size
- self.width = sample_size
-
- self.patch_size = patch_size
- self.pos_embed = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels,
- embed_dim=inner_dim,
- )
-
- # 3. Define transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- BasicMVTransformerBlock(
- inner_dim,
- num_attention_heads,
- attention_head_dim,
- dropout=dropout,
- cross_attention_dim=cross_attention_dim,
- activation_fn=activation_fn,
- num_embeds_ada_norm=num_embeds_ada_norm,
- attention_bias=attention_bias,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- norm_type=norm_type,
- norm_elementwise_affine=norm_elementwise_affine,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- mvcd_attention=mvcd_attention
- )
- for d in range(num_layers)
- ]
- )
-
- # 4. Define output layers
- self.out_channels = in_channels if out_channels is None else out_channels
- if self.is_input_continuous:
- # TODO: should use out_channels for continuous projections
- if use_linear_projection:
- self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
- else:
- self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
- elif self.is_input_vectorized:
- self.norm_out = nn.LayerNorm(inner_dim)
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- elif self.is_input_patches:
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- class_labels: Optional[torch.LongTensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- attention_mask: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ):
- """
- The [`Transformer2DModel`] forward method.
-
- Args:
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
- Input `hidden_states`.
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
- self-attention.
- timestep ( `torch.LongTensor`, *optional*):
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
-                `AdaLayerNormZero`.
- encoder_attention_mask ( `torch.Tensor`, *optional*):
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
-
- * Mask `(batch, sequence_length)` True = keep, False = discard.
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
-
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
- above. This bias will be added to the cross-attention scores.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
- tuple.
-
- Returns:
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
- """
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
- # expects mask of shape:
- # [batch, key_tokens]
- # adds singleton query_tokens dimension:
- # [batch, 1, key_tokens]
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
- if attention_mask is not None and attention_mask.ndim == 2:
- # assume that mask is expressed as:
- # (1 = keep, 0 = discard)
- # convert mask into a bias that can be added to attention scores:
- # (keep = +0, discard = -10000.0)
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
- # 1. Input
- if self.is_input_continuous:
- batch, _, height, width = hidden_states.shape
- residual = hidden_states
-
- hidden_states = self.norm(hidden_states)
- if not self.use_linear_projection:
- hidden_states = self.proj_in(hidden_states)
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- else:
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- hidden_states = self.proj_in(hidden_states)
- elif self.is_input_vectorized:
- hidden_states = self.latent_image_embedding(hidden_states)
- elif self.is_input_patches:
- hidden_states = self.pos_embed(hidden_states)
-
- # 2. Blocks
- for block in self.transformer_blocks:
- hidden_states = block(
- hidden_states,
- attention_mask=attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- timestep=timestep,
- cross_attention_kwargs=cross_attention_kwargs,
- class_labels=class_labels,
- )
-
- # 3. Output
- if self.is_input_continuous:
- if not self.use_linear_projection:
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
- hidden_states = self.proj_out(hidden_states)
- else:
- hidden_states = self.proj_out(hidden_states)
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
- output = hidden_states + residual
- elif self.is_input_vectorized:
- hidden_states = self.norm_out(hidden_states)
- logits = self.out(hidden_states)
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
- logits = logits.permute(0, 2, 1)
-
- # log(p(x_0))
- output = F.log_softmax(logits.double(), dim=1).float()
- elif self.is_input_patches:
- # TODO: cleanup!
- conditioning = self.transformer_blocks[0].norm1.emb(
- timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
- hidden_states = self.proj_out_2(hidden_states)
-
- # unpatchify
- height = width = int(hidden_states.shape[1] ** 0.5)
- hidden_states = hidden_states.reshape(
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
- )
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
- output = hidden_states.reshape(
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
- )
-
- if not return_dict:
- return (output,)
-
- return TransformerMV2DModelOutput(sample=output)
-
-
-@maybe_allow_in_graph
-class BasicMVTransformerBlock(nn.Module):
- r"""
- A basic Transformer block.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- only_cross_attention (`bool`, *optional*):
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
- double_self_attention (`bool`, *optional*):
- Whether to use two self-attention layers. In this case no cross attention layers are used.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- num_embeds_ada_norm (:
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
- attention_bias (:
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- dropout=0.0,
- cross_attention_dim: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- attention_bias: bool = False,
- only_cross_attention: bool = False,
- double_self_attention: bool = False,
- upcast_attention: bool = False,
- norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm",
- final_dropout: bool = False,
- num_views: int = 1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- sparse_mv_attention: bool = False,
- mvcd_attention: bool = False
- ):
- super().__init__()
- self.only_cross_attention = only_cross_attention
-
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
-
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
- raise ValueError(
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
- )
-
- # Define 3 blocks. Each block has its own normalization layer.
- # 1. Self-Attn
- if self.use_ada_layer_norm:
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
- elif self.use_ada_layer_norm_zero:
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
- else:
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
- self.multiview_attention = multiview_attention
- self.sparse_mv_attention = sparse_mv_attention
- self.mvcd_attention = mvcd_attention
-
- self.attn1 = CustomAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=MVAttnProcessor()
- )
-
- # 2. Cross-Attn
- if cross_attention_dim is not None or double_self_attention:
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
- # the second cross attention block.
- self.norm2 = (
- AdaLayerNorm(dim, num_embeds_ada_norm)
- if self.use_ada_layer_norm
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- )
- self.attn2 = Attention(
- query_dim=dim,
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- upcast_attention=upcast_attention,
- ) # is self-attn if encoder_hidden_states is none
- else:
- self.norm2 = None
- self.attn2 = None
-
- # 3. Feed-forward
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = 0
-
- self.num_views = num_views
-
- self.cd_attention_last = cd_attention_last
-
- if self.cd_attention_last:
- # Joint task -Attn
- self.attn_joint_last = CustomJointAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=JointAttnProcessor()
- )
- nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
- self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
-
-
- self.cd_attention_mid = cd_attention_mid
-
- if self.cd_attention_mid:
- print("cross-domain attn in the middle")
- # Joint task -Attn
- self.attn_joint_mid = CustomJointAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=JointAttnProcessor()
- )
- nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
- self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
-
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- self._chunk_dim = dim
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- class_labels: Optional[torch.LongTensor] = None,
- ):
- assert attention_mask is None # not supported yet
- # Notice that normalization is always applied before the real computation in the following blocks.
- # 1. Self-Attention
- if self.use_ada_layer_norm:
- norm_hidden_states = self.norm1(hidden_states, timestep)
- elif self.use_ada_layer_norm_zero:
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- else:
- norm_hidden_states = self.norm1(hidden_states)
-
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
- attn_output = self.attn1(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
- attention_mask=attention_mask,
- num_views=self.num_views,
- multiview_attention=self.multiview_attention,
- sparse_mv_attention=self.sparse_mv_attention,
- mvcd_attention=self.mvcd_attention,
- **cross_attention_kwargs,
- )
-
-
- if self.use_ada_layer_norm_zero:
- attn_output = gate_msa.unsqueeze(1) * attn_output
- hidden_states = attn_output + hidden_states
-
- # joint cross-domain attention (mid-block)
- if self.cd_attention_mid:
- norm_hidden_states = (
- self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
- )
- hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
-
- # 2. Cross-Attention
- if self.attn2 is not None:
- norm_hidden_states = (
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
- )
-
- attn_output = self.attn2(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- **cross_attention_kwargs,
- )
- hidden_states = attn_output + hidden_states
-
- # 3. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
-
- if self.use_ada_layer_norm_zero:
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
- if self._chunk_size is not None:
- # "feed_forward_chunk_size" can be used to save memory
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
- raise ValueError(
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
- )
-
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
- ff_output = torch.cat(
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
- dim=self._chunk_dim,
- )
- else:
- ff_output = self.ff(norm_hidden_states)
-
- if self.use_ada_layer_norm_zero:
- ff_output = gate_mlp.unsqueeze(1) * ff_output
-
- hidden_states = ff_output + hidden_states
-
- if self.cd_attention_last:
- norm_hidden_states = (
- self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
- )
- hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
-
- return hidden_states
-
-
-class CustomAttention(Attention):
- def set_use_memory_efficient_attention_xformers(
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
- ):
- processor = XFormersMVAttnProcessor()
- self.set_processor(processor)
- # print("using xformers attention processor")
-
-
-class CustomJointAttention(Attention):
- def set_use_memory_efficient_attention_xformers(
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
- ):
- processor = XFormersJointAttnProcessor()
- self.set_processor(processor)
- # print("using xformers attention processor")
-
-class MVAttnProcessor:
- r"""
- Processor for multi-view self-attention: each view's queries attend to the keys and values of all views.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_views=1,
- multiview_attention=True,
- sparse_mv_attention=False,  # accepted for interface parity with XFormersMVAttnProcessor; unused here
- mvcd_attention=False,  # accepted for interface parity with XFormersMVAttnProcessor; unused here
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
- # query torch.Size([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
- # pdb.set_trace()
- # multi-view self-attention
- if multiview_attention:
- if num_views <= 6:
- # after switching to xformers, training with 6 views is feasible
- key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
- value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
- else:  # apply sparse attention
- pass
- # print("use sparse attention")
- # # seems that the sparse random sampling cause problems
- # # don't use random sampling, just fix the indexes
- # onekey = rearrange(key, "(b t) d c -> b t d c", t=num_views)
- # onevalue = rearrange(value, "(b t) d c -> b t d c", t=num_views)
- # allkeys = []
- # allvalues = []
- # all_indexes = {
- # 0 : [0, 2, 3, 4],
- # 1: [0, 1, 3, 5],
- # 2: [0, 2, 3, 4],
- # 3: [0, 2, 3, 4],
- # 4: [0, 2, 3, 4],
- # 5: [0, 1, 3, 5]
- # }
- # for jj in range(num_views):
- # # valid_index = [x for x in range(0, num_views) if x!= jj]
- # # indexes = random.sample(valid_index, 3) + [jj] + [0]
- # indexes = all_indexes[jj]
-
- # indexes = torch.tensor(indexes).long().to(key.device)
- # allkeys.append(onekey[:, indexes])
- # allvalues.append(onevalue[:, indexes])
- # keys = torch.stack(allkeys, dim=1) # checked, should be dim=1
- # values = torch.stack(allvalues, dim=1)
- # key = rearrange(keys, 'b t f d c -> (b t) (f d) c')
- # value = rearrange(values, 'b t f d c -> (b t) (f d) c')
-
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
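- # Illustrative shape walk-through of the multi-view flattening above (toy
- # sizes, assumed only for this sketch): with b=1, t=num_views=4, d=1024
- # tokens per view and c=320 channels,
- #   key                               : (b*t, d, c) = (4, 1024, 320)
- #   rearrange "(b t) d c -> b (t d) c": (1, 4096, 320)
- #   repeat_interleave(t, dim=0)       : (4, 4096, 320)
- # so each view's 1024 queries attend to the 4*1024 key/value tokens of all
- # views, while the query tensor keeps its per-view shape (4, 1024, 320).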
-
-class XFormersMVAttnProcessor:
- r"""
- xformers (memory-efficient) processor for multi-view self-attention, with optional sparse and cross-domain variants.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_views=1,
- multiview_attention=True,
- sparse_mv_attention=False,
- mvcd_attention=False,
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- # from yuancheng; here attention_mask is None
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key_raw = attn.to_k(encoder_hidden_states)
- value_raw = attn.to_v(encoder_hidden_states)
-
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
- # query torch.Size([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
- # pdb.set_trace()
- # multi-view self-attention
- if multiview_attention:
- if not sparse_mv_attention:
- key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
- value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
- else:
- key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
- value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
- key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
- value = torch.cat([value_front, value_raw], dim=1)
-
- if mvcd_attention:
- # memory efficient, cross domain attention
- key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
- value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
- key_cross = torch.concat([key_1, key_0], dim=0)
- value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
- key = torch.cat([key, key_cross], dim=1)
- value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
- else:
- # print("don't use multiview attention.")
- key = key_raw
- value = value_raw
-
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
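- # Illustrative shapes for the sparse and cross-domain branches above (sizes
- # are assumptions for this sketch): with key_raw of shape (b*t, d, c),
- #   sparse_mv_attention: key_front = front view (index 0) repeated per view -> (b*t, d, c)
- #                        key = cat([key_front, key_raw], dim=1)             -> (b*t, 2*d, c)
- #   mvcd_attention:      key_0, key_1 = chunk(key_raw, 2, dim=0)   # the two domains
- #                        key_cross = cat([key_1, key_0], dim=0)    # swap domain halves
- #                        key = cat([key, key_cross], dim=1)        # d extra tokens per query
- # so every view also attends to the front view, and each domain additionally
- # attends to the other domain's tokens; value is expanded the same way.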
-
-class XFormersJointAttnProcessor:
- r"""
- xformers (memory-efficient) processor for joint attention across two task domains (num_tasks == 2).
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_tasks=2
- ):
-
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- # from yuancheng; here attention_mask is None
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- assert num_tasks == 2 # only support two tasks now
-
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
-
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
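- # Illustrative shapes for the joint (cross-domain) attention above: the batch
- # stacks the two tasks, so with key of shape (2*B, d, c),
- #   key_0, key_1 = chunk(key, 2, dim=0)   # (B, d, c) each
- #   key = cat([key_0, key_1], dim=1)      # (B, 2*d, c)
- #   key = cat([key, key], dim=0)          # (2*B, 2*d, c)
- # so the queries of both tasks attend to the concatenated tokens of both
- # tasks; value is expanded identically before the attention call.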
-
-class JointAttnProcessor:
- r"""
- Processor for joint attention across two task domains (num_tasks == 2).
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_tasks=2
- ):
-
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- assert num_tasks == 2 # only support two tasks now
-
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
-
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
diff --git a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_rowwise.py b/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_rowwise.py
deleted file mode 100644
index 5bf9643718a58410e099f71af2888f8da274fc5d..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_rowwise.py
+++ /dev/null
@@ -1,978 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.embeddings import ImagePositionalEmbeddings
-from diffusers.utils import BaseOutput, deprecate
-from diffusers.utils.torch_utils import maybe_allow_in_graph
-from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
-from diffusers.models.embeddings import PatchEmbed
-from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.import_utils import is_xformers_available
-
-from einops import rearrange
-import pdb
-import random
-import math
-
-
-if is_xformers_available():
- import xformers
- import xformers.ops
-else:
- xformers = None
-
-
-@dataclass
-class TransformerMV2DModelOutput(BaseOutput):
- """
- The output of [`TransformerMV2DModel`].
-
- Args:
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
- distributions for the unnoised latent pixels.
- """
-
- sample: torch.FloatTensor
-
-
-class TransformerMV2DModel(ModelMixin, ConfigMixin):
- """
- A 2D Transformer model for image-like data.
-
- Parameters:
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
- in_channels (`int`, *optional*):
- The number of channels in the input and output (specify if the input is **continuous**).
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
- This is fixed during training since it is used to learn a number of position embeddings.
- num_vector_embeds (`int`, *optional*):
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
- Includes the class for the masked latent pixel.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
- num_embeds_ada_norm ( `int`, *optional*):
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
- added to the hidden states.
-
- During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
- attention_bias (`bool`, *optional*):
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
- """
-
- @register_to_config
- def __init__(
- self,
- num_attention_heads: int = 16,
- attention_head_dim: int = 88,
- in_channels: Optional[int] = None,
- out_channels: Optional[int] = None,
- num_layers: int = 1,
- dropout: float = 0.0,
- norm_num_groups: int = 32,
- cross_attention_dim: Optional[int] = None,
- attention_bias: bool = False,
- sample_size: Optional[int] = None,
- num_vector_embeds: Optional[int] = None,
- patch_size: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- norm_type: str = "layer_norm",
- norm_elementwise_affine: bool = True,
- num_views: int = 1,
- cd_attention_last: bool=False,
- cd_attention_mid: bool=False,
- multiview_attention: bool=True,
- sparse_mv_attention: bool = True, # not used
- mvcd_attention: bool=False
- ):
- super().__init__()
- self.use_linear_projection = use_linear_projection
- self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
- inner_dim = num_attention_heads * attention_head_dim
-
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
- # Define whether input is continuous or discrete depending on configuration
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
- self.is_input_vectorized = num_vector_embeds is not None
- self.is_input_patches = in_channels is not None and patch_size is not None
-
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
- deprecation_message = (
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
- " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
- " Please make sure to update the config accordingly, as leaving `norm_type` unchanged might lead to incorrect"
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
- )
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
- norm_type = "ada_norm"
-
- if self.is_input_continuous and self.is_input_vectorized:
- raise ValueError(
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
- " sure that either `in_channels` or `num_vector_embeds` is None."
- )
- elif self.is_input_vectorized and self.is_input_patches:
- raise ValueError(
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
- " sure that either `num_vector_embeds` or `num_patches` is None."
- )
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
- raise ValueError(
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
- )
-
- # 2. Define input layers
- if self.is_input_continuous:
- self.in_channels = in_channels
-
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
- if use_linear_projection:
- self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
- else:
- self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
- elif self.is_input_vectorized:
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
-
- self.height = sample_size
- self.width = sample_size
- self.num_vector_embeds = num_vector_embeds
- self.num_latent_pixels = self.height * self.width
-
- self.latent_image_embedding = ImagePositionalEmbeddings(
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
- )
- elif self.is_input_patches:
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
-
- self.height = sample_size
- self.width = sample_size
-
- self.patch_size = patch_size
- self.pos_embed = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels,
- embed_dim=inner_dim,
- )
-
- # 3. Define transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- BasicMVTransformerBlock(
- inner_dim,
- num_attention_heads,
- attention_head_dim,
- dropout=dropout,
- cross_attention_dim=cross_attention_dim,
- activation_fn=activation_fn,
- num_embeds_ada_norm=num_embeds_ada_norm,
- attention_bias=attention_bias,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- norm_type=norm_type,
- norm_elementwise_affine=norm_elementwise_affine,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- mvcd_attention=mvcd_attention
- )
- for d in range(num_layers)
- ]
- )
-
- # 4. Define output layers
- self.out_channels = in_channels if out_channels is None else out_channels
- if self.is_input_continuous:
- # TODO: should use out_channels for continuous projections
- if use_linear_projection:
- self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
- else:
- self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
- elif self.is_input_vectorized:
- self.norm_out = nn.LayerNorm(inner_dim)
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- elif self.is_input_patches:
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- class_labels: Optional[torch.LongTensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- attention_mask: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ):
- """
- The [`TransformerMV2DModel`] forward method.
-
- Args:
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
- Input `hidden_states`.
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
- self-attention.
- timestep ( `torch.LongTensor`, *optional*):
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
- `AdaLayerNormZero`.
- encoder_attention_mask ( `torch.Tensor`, *optional*):
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
-
- * Mask `(batch, sequence_length)` True = keep, False = discard.
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
-
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
- above. This bias will be added to the cross-attention scores.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
- tuple.
-
- Returns:
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
- """
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
- # expects mask of shape:
- # [batch, key_tokens]
- # adds singleton query_tokens dimension:
- # [batch, 1, key_tokens]
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
- if attention_mask is not None and attention_mask.ndim == 2:
- # assume that mask is expressed as:
- # (1 = keep, 0 = discard)
- # convert mask into a bias that can be added to attention scores:
- # (keep = +0, discard = -10000.0)
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
- # 1. Input
- if self.is_input_continuous:
- batch, _, height, width = hidden_states.shape
- residual = hidden_states
-
- hidden_states = self.norm(hidden_states)
- if not self.use_linear_projection:
- hidden_states = self.proj_in(hidden_states)
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- else:
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- hidden_states = self.proj_in(hidden_states)
- elif self.is_input_vectorized:
- hidden_states = self.latent_image_embedding(hidden_states)
- elif self.is_input_patches:
- hidden_states = self.pos_embed(hidden_states)
-
- # 2. Blocks
- for block in self.transformer_blocks:
- hidden_states = block(
- hidden_states,
- attention_mask=attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- timestep=timestep,
- cross_attention_kwargs=cross_attention_kwargs,
- class_labels=class_labels,
- )
-
- # 3. Output
- if self.is_input_continuous:
- if not self.use_linear_projection:
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
- hidden_states = self.proj_out(hidden_states)
- else:
- hidden_states = self.proj_out(hidden_states)
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
- output = hidden_states + residual
- elif self.is_input_vectorized:
- hidden_states = self.norm_out(hidden_states)
- logits = self.out(hidden_states)
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
- logits = logits.permute(0, 2, 1)
-
- # log(p(x_0))
- output = F.log_softmax(logits.double(), dim=1).float()
- elif self.is_input_patches:
- # TODO: cleanup!
- conditioning = self.transformer_blocks[0].norm1.emb(
- timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
- hidden_states = self.proj_out_2(hidden_states)
-
- # unpatchify
- height = width = int(hidden_states.shape[1] ** 0.5)
- hidden_states = hidden_states.reshape(
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
- )
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
- output = hidden_states.reshape(
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
- )
-
- if not return_dict:
- return (output,)
-
- return TransformerMV2DModelOutput(sample=output)
-
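- # Minimal usage sketch (hypothetical sizes; mirrors the constructor and
- # forward signatures above, continuous-input path):
- #   model = TransformerMV2DModel(num_attention_heads=8, attention_head_dim=40,
- #                                in_channels=320, cross_attention_dim=768,
- #                                num_views=4)
- #   x = torch.randn(2 * 4, 320, 32, 32)                   # (batch*views, C, H, W)
- #   text = torch.randn(2 * 4, 77, 768)                    # per-view text embeddings
- #   sample = model(x, encoder_hidden_states=text).sample  # same shape as x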
-
-@maybe_allow_in_graph
-class BasicMVTransformerBlock(nn.Module):
- r"""
- A basic Transformer block.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- only_cross_attention (`bool`, *optional*):
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
- double_self_attention (`bool`, *optional*):
- Whether to use two self-attention layers. In this case no cross attention layers are used.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- num_embeds_ada_norm (`int`, *optional*):
- The number of diffusion steps used during training. See `Transformer2DModel`.
- attention_bias (`bool`, *optional*, defaults to `False`):
- Configure if the attentions should contain a bias parameter.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- dropout=0.0,
- cross_attention_dim: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- attention_bias: bool = False,
- only_cross_attention: bool = False,
- double_self_attention: bool = False,
- upcast_attention: bool = False,
- norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm",
- final_dropout: bool = False,
- num_views: int = 1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- mvcd_attention: bool = False,
- rowwise_attention: bool = True
- ):
- super().__init__()
- self.only_cross_attention = only_cross_attention
-
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
-
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
- raise ValueError(
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
- )
-
- # Define 3 blocks. Each block has its own normalization layer.
- # 1. Self-Attn
- if self.use_ada_layer_norm:
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
- elif self.use_ada_layer_norm_zero:
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
- else:
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
- self.multiview_attention = multiview_attention
- self.mvcd_attention = mvcd_attention
- self.rowwise_attention = multiview_attention and rowwise_attention
-
- # row-wise multi-view attention
-
- print('INFO: using row-wise attention...')
-
- self.attn1 = CustomAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=MVAttnProcessor()
- )
-
- # 2. Cross-Attn
- if cross_attention_dim is not None or double_self_attention:
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
- # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
- # the second cross attention block.
- self.norm2 = (
- AdaLayerNorm(dim, num_embeds_ada_norm)
- if self.use_ada_layer_norm
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- )
- self.attn2 = Attention(
- query_dim=dim,
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- upcast_attention=upcast_attention,
- ) # acts as self-attention if encoder_hidden_states is None
- else:
- self.norm2 = None
- self.attn2 = None
-
- # 3. Feed-forward
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = 0
-
- self.num_views = num_views
-
- self.cd_attention_last = cd_attention_last
-
- if self.cd_attention_last:
- # Joint (cross-domain) attention applied at the end of the block
- self.attn_joint = CustomJointAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=JointAttnProcessor()
- )
- nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
- self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
-
-
- self.cd_attention_mid = cd_attention_mid
-
- if self.cd_attention_mid:
- print("INFO: using joint cross-domain attention mid-block")
- # Joint (cross-domain) attention applied after self-attention
- self.attn_joint_twice = CustomJointAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=JointAttnProcessor()
- )
- nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
- self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
-
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- self._chunk_dim = dim
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- class_labels: Optional[torch.LongTensor] = None,
- ):
- assert attention_mask is None # not supported yet
- # Notice that normalization is always applied before the real computation in the following blocks.
- # 1. Self-Attention
- if self.use_ada_layer_norm:
- norm_hidden_states = self.norm1(hidden_states, timestep)
- elif self.use_ada_layer_norm_zero:
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- else:
- norm_hidden_states = self.norm1(hidden_states)
-
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
- attn_output = self.attn1(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
- attention_mask=attention_mask,
- multiview_attention=self.multiview_attention,
- mvcd_attention=self.mvcd_attention,
- num_views=self.num_views,
- **cross_attention_kwargs,
- )
-
- if self.use_ada_layer_norm_zero:
- attn_output = gate_msa.unsqueeze(1) * attn_output
- hidden_states = attn_output + hidden_states
-
- # joint cross-domain attention (mid-block)
- if self.cd_attention_mid:
- norm_hidden_states = (
- self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
- )
- hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states
-
- # 2. Cross-Attention
- if self.attn2 is not None:
- norm_hidden_states = (
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
- )
-
- attn_output = self.attn2(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- **cross_attention_kwargs,
- )
- hidden_states = attn_output + hidden_states
-
- # 3. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
-
- if self.use_ada_layer_norm_zero:
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
- if self._chunk_size is not None:
- # "feed_forward_chunk_size" can be used to save memory
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
- raise ValueError(
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
- )
-
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
- ff_output = torch.cat(
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
- dim=self._chunk_dim,
- )
- else:
- ff_output = self.ff(norm_hidden_states)
-
- if self.use_ada_layer_norm_zero:
- ff_output = gate_mlp.unsqueeze(1) * ff_output
-
- hidden_states = ff_output + hidden_states
-
- if self.cd_attention_last:
- norm_hidden_states = (
- self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
- )
- hidden_states = self.attn_joint(norm_hidden_states) + hidden_states
-
- return hidden_states
-
-
-class CustomAttention(Attention):
- def set_use_memory_efficient_attention_xformers(
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
- ):
- processor = XFormersMVAttnProcessor()
- self.set_processor(processor)
- # print("using xformers attention processor")
-
-
-class CustomJointAttention(Attention):
- def set_use_memory_efficient_attention_xformers(
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
- ):
- processor = XFormersJointAttnProcessor()
- self.set_processor(processor)
- # print("using xformers attention processor")
-
-class MVAttnProcessor:
- r"""
- Processor for row-wise multi-view self-attention: tokens in the same image row attend across all views.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_views=1,
- multiview_attention=True,
- mvcd_attention=False,  # accepted for interface parity with XFormersMVAttnProcessor; unused here
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- height = int(math.sqrt(sequence_length))  # assumes a square token grid (height == width)
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
- # query torch.Size([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
- # pdb.set_trace()
- # multi-view self-attention
- key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
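- # Illustrative shapes for the row-wise multi-view attention above (toy sizes,
- # assumed for this sketch): with b=1, v=num_views=6, a 32x32 token grid
- # (h = w = 32) and c=320,
- #   key: (b*v, h*w, c) = (6, 1024, 320)
- #   rearrange "(b v) (h w) c -> (b h) (v w) c" -> (32, 192, 320)
- # so attention is computed per image row over the v*w = 192 tokens of that
- # row across all views, instead of over all v*h*w tokens at once.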
-
-class XFormersMVAttnProcessor:
- r"""
- xformers (memory-efficient) processor for row-wise multi-view self-attention, with optional cross-domain attention.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_views=1,
- multiview_attention=True,
- mvcd_attention=False,
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- height = int(math.sqrt(sequence_length))  # assumes a square token grid (height == width)
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # from yuancheng; here attention_mask is None
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
-
- if attn.group_norm is not None:
- print('Warning: group norm is enabled; take care when combining it with row-wise attention')
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key_raw = attn.to_k(encoder_hidden_states)
- value_raw = attn.to_v(encoder_hidden_states)
-
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
- # pdb.set_trace()
-
- key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
- if mvcd_attention:
- # memory efficient, cross domain attention
- key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
- value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
- key_cross = torch.concat([key_1, key_0], dim=0)
- value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
- key = torch.cat([key, key_cross], dim=1)
- value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
-
-
- query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- # print(hidden_states.shape)
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class XFormersJointAttnProcessor:
- r"""
- xformers (memory-efficient) processor for joint attention across two task domains (num_tasks == 2).
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_tasks=2
- ):
-
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- # from yuancheng; here attention_mask is None
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- assert num_tasks == 2 # only support two tasks now
-
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
-
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class JointAttnProcessor:
- r"""
- Processor for joint attention across two task domains (num_tasks == 2).
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_tasks=2
- ):
-
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- assert num_tasks == 2 # only support two tasks now
-
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
-
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
\ No newline at end of file
diff --git a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_self_rowwise.py b/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_self_rowwise.py
deleted file mode 100644
index 4eca703670c8d2f1d0fbb08027694e87acb8c926..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/mvdiffusion/models/transformer_mv2d_self_rowwise.py
+++ /dev/null
@@ -1,1038 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.embeddings import ImagePositionalEmbeddings
-from diffusers.utils import BaseOutput, deprecate
-from diffusers.utils.torch_utils import maybe_allow_in_graph
-from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
-from diffusers.models.embeddings import PatchEmbed
-from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.import_utils import is_xformers_available
-
-from einops import rearrange
-import pdb
-import random
-import math
-
-
-if is_xformers_available():
- import xformers
- import xformers.ops
-else:
- xformers = None
-
-
-@dataclass
-class TransformerMV2DModelOutput(BaseOutput):
- """
- The output of [`TransformerMV2DModel`].
-
- Args:
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
- distributions for the unnoised latent pixels.
- """
-
- sample: torch.FloatTensor
-
-
-class TransformerMV2DModel(ModelMixin, ConfigMixin):
- """
- A 2D Transformer model for image-like data.
-
- Parameters:
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
- in_channels (`int`, *optional*):
- The number of channels in the input and output (specify if the input is **continuous**).
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
- This is fixed during training since it is used to learn a number of position embeddings.
- num_vector_embeds (`int`, *optional*):
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
- Includes the class for the masked latent pixel.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
- num_embeds_ada_norm ( `int`, *optional*):
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
- added to the hidden states.
-
- During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
- attention_bias (`bool`, *optional*):
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
- """
-
- @register_to_config
- def __init__(
- self,
- num_attention_heads: int = 16,
- attention_head_dim: int = 88,
- in_channels: Optional[int] = None,
- out_channels: Optional[int] = None,
- num_layers: int = 1,
- dropout: float = 0.0,
- norm_num_groups: int = 32,
- cross_attention_dim: Optional[int] = None,
- attention_bias: bool = False,
- sample_size: Optional[int] = None,
- num_vector_embeds: Optional[int] = None,
- patch_size: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- norm_type: str = "layer_norm",
- norm_elementwise_affine: bool = True,
- num_views: int = 1,
- cd_attention_mid: bool=False,
- cd_attention_last: bool=False,
- multiview_attention: bool=True,
- sparse_mv_attention: bool = True, # not used
- mvcd_attention: bool=False,
- use_dino: bool=False
- ):
- super().__init__()
- self.use_linear_projection = use_linear_projection
- self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
- inner_dim = num_attention_heads * attention_head_dim
-
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
- # Define whether input is continuous or discrete depending on configuration
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
- self.is_input_vectorized = num_vector_embeds is not None
- self.is_input_patches = in_channels is not None and patch_size is not None
-
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
- deprecation_message = (
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
- )
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
- norm_type = "ada_norm"
-
- if self.is_input_continuous and self.is_input_vectorized:
- raise ValueError(
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
- " sure that either `in_channels` or `num_vector_embeds` is None."
- )
- elif self.is_input_vectorized and self.is_input_patches:
- raise ValueError(
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
- " sure that either `num_vector_embeds` or `num_patches` is None."
- )
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
- raise ValueError(
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
- )
-
- # 2. Define input layers
- if self.is_input_continuous:
- self.in_channels = in_channels
-
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
- if use_linear_projection:
- self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
- else:
- self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
- elif self.is_input_vectorized:
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
-
- self.height = sample_size
- self.width = sample_size
- self.num_vector_embeds = num_vector_embeds
- self.num_latent_pixels = self.height * self.width
-
- self.latent_image_embedding = ImagePositionalEmbeddings(
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
- )
- elif self.is_input_patches:
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
-
- self.height = sample_size
- self.width = sample_size
-
- self.patch_size = patch_size
- self.pos_embed = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels,
- embed_dim=inner_dim,
- )
-
- # 3. Define transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- BasicMVTransformerBlock(
- inner_dim,
- num_attention_heads,
- attention_head_dim,
- dropout=dropout,
- cross_attention_dim=cross_attention_dim,
- activation_fn=activation_fn,
- num_embeds_ada_norm=num_embeds_ada_norm,
- attention_bias=attention_bias,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- norm_type=norm_type,
- norm_elementwise_affine=norm_elementwise_affine,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- for d in range(num_layers)
- ]
- )
-
- # 4. Define output layers
- self.out_channels = in_channels if out_channels is None else out_channels
- if self.is_input_continuous:
- # TODO: should use out_channels for continuous projections
- if use_linear_projection:
- self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
- else:
- self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
- elif self.is_input_vectorized:
- self.norm_out = nn.LayerNorm(inner_dim)
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- elif self.is_input_patches:
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- dino_feature: Optional[torch.Tensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- class_labels: Optional[torch.LongTensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- attention_mask: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ):
- """
- The [`TransformerMV2DModel`] forward method.
-
- Args:
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
- Input `hidden_states`.
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
- self-attention.
- timestep ( `torch.LongTensor`, *optional*):
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
- `AdaLayerNormZero`.
- encoder_attention_mask ( `torch.Tensor`, *optional*):
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
-
- * Mask `(batch, sequence_length)` True = keep, False = discard.
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
-
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
- above. This bias will be added to the cross-attention scores.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain
- tuple.
-
- Returns:
- If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
- """
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
- # expects mask of shape:
- # [batch, key_tokens]
- # adds singleton query_tokens dimension:
- # [batch, 1, key_tokens]
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
- if attention_mask is not None and attention_mask.ndim == 2:
- # assume that mask is expressed as:
- # (1 = keep, 0 = discard)
- # convert mask into a bias that can be added to attention scores:
- # (keep = +0, discard = -10000.0)
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
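- # Worked example for the conversions above: a 2D mask [[1, 1, 0]] becomes the additive
- # bias [[0, 0, -10000]], then gains a singleton query dimension -> shape (batch, 1, key_tokens),
- # so it broadcasts over the query tokens when added to the attention scores.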
-
- # 1. Input
- if self.is_input_continuous:
- batch, _, height, width = hidden_states.shape
- residual = hidden_states
-
- hidden_states = self.norm(hidden_states)
- if not self.use_linear_projection:
- hidden_states = self.proj_in(hidden_states)
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- else:
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- hidden_states = self.proj_in(hidden_states)
- elif self.is_input_vectorized:
- hidden_states = self.latent_image_embedding(hidden_states)
- elif self.is_input_patches:
- hidden_states = self.pos_embed(hidden_states)
-
- # 2. Blocks
- for block in self.transformer_blocks:
- hidden_states = block(
- hidden_states,
- attention_mask=attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- dino_feature=dino_feature,
- encoder_attention_mask=encoder_attention_mask,
- timestep=timestep,
- cross_attention_kwargs=cross_attention_kwargs,
- class_labels=class_labels,
- )
-
- # 3. Output
- if self.is_input_continuous:
- if not self.use_linear_projection:
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
- hidden_states = self.proj_out(hidden_states)
- else:
- hidden_states = self.proj_out(hidden_states)
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
- output = hidden_states + residual
- elif self.is_input_vectorized:
- hidden_states = self.norm_out(hidden_states)
- logits = self.out(hidden_states)
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
- logits = logits.permute(0, 2, 1)
-
- # log(p(x_0))
- output = F.log_softmax(logits.double(), dim=1).float()
- elif self.is_input_patches:
- # TODO: cleanup!
- conditioning = self.transformer_blocks[0].norm1.emb(
- timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
- hidden_states = self.proj_out_2(hidden_states)
-
- # unpatchify
- height = width = int(hidden_states.shape[1] ** 0.5)
- hidden_states = hidden_states.reshape(
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
- )
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
- output = hidden_states.reshape(
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
- )
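- # Shape sketch for the unpatchify above (illustrative values): with patch_size=2,
- # out_channels=4 and 256 patch tokens, hidden_states goes
- # (N, 256, 2*2*4) -> (N, 16, 16, 2, 2, 4) -> einsum "nhwpqc->nchpwq" -> (N, 4, 16, 2, 16, 2)
- # -> reshape -> (N, 4, 32, 32), i.e. the patches are stitched back into a full image grid.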
-
- if not return_dict:
- return (output,)
-
- return TransformerMV2DModelOutput(sample=output)
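- # Usage sketch (hypothetical values for illustration; not taken from an Era3D config):
- #   model = TransformerMV2DModel(num_attention_heads=8, attention_head_dim=40,
- #                                in_channels=320, cross_attention_dim=1024, num_views=6)
- #   latents = torch.randn(2 * 6, 320, 32, 32)                    # (batch * num_views, C, H, W)
- #   text = torch.randn(2 * 6, 77, 1024)                          # per-view conditioning embeddings
- #   sample = model(latents, encoder_hidden_states=text).sample   # same shape as `latents`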
-
-
-@maybe_allow_in_graph
-class BasicMVTransformerBlock(nn.Module):
- r"""
- A basic multi-view Transformer block with optional row-wise cross-view attention.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- only_cross_attention (`bool`, *optional*):
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
- double_self_attention (`bool`, *optional*):
- Whether to use two self-attention layers. In this case no cross attention layers are used.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- num_embeds_ada_norm (`int`, *optional*):
- The number of diffusion steps used during training. See `Transformer2DModel`.
- attention_bias (`bool`, *optional*, defaults to `False`):
- Configure if the attentions should contain a bias parameter.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- dropout=0.0,
- cross_attention_dim: Optional[int] = None,
- activation_fn: str = "geglu",
- num_embeds_ada_norm: Optional[int] = None,
- attention_bias: bool = False,
- only_cross_attention: bool = False,
- double_self_attention: bool = False,
- upcast_attention: bool = False,
- norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm",
- final_dropout: bool = False,
- num_views: int = 1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- mvcd_attention: bool = False,
- rowwise_attention: bool = True,
- use_dino: bool = False
- ):
- super().__init__()
- self.only_cross_attention = only_cross_attention
-
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
-
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
- raise ValueError(
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
- )
-
- # Define 3 blocks. Each block has its own normalization layer.
- # 1. Self-Attn
- if self.use_ada_layer_norm:
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
- elif self.use_ada_layer_norm_zero:
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
- else:
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
- self.multiview_attention = multiview_attention
- self.mvcd_attention = mvcd_attention
- self.cd_attention_mid = cd_attention_mid
- self.rowwise_attention = multiview_attention and rowwise_attention
-
- if mvcd_attention and (not cd_attention_mid):
- # add cross domain attn to self attn
- self.attn1 = CustomJointAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=JointAttnProcessor()
- )
- else:
- self.attn1 = Attention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention
- )
- # 1.1 rowwise multiview attention
- if self.rowwise_attention:
- # print('INFO: using self+row_wise mv attention...')
- self.norm_mv = (
- AdaLayerNorm(dim, num_embeds_ada_norm)
- if self.use_ada_layer_norm
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- )
- self.attn_mv = CustomAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- processor=MVAttnProcessor()
- )
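- # The output projection of the new row-wise attention is zero-initialized below, so at the
- # start of fine-tuning this branch contributes (almost) nothing and the block behaves like
- # the pretrained self-attention-only block.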
- nn.init.zeros_(self.attn_mv.to_out[0].weight.data)
- else:
- self.norm_mv = None
- self.attn_mv = None
-
- # # 1.2 rowwise cross-domain attn
- # if mvcd_attention:
- # self.attn_joint = CustomJointAttention(
- # query_dim=dim,
- # heads=num_attention_heads,
- # dim_head=attention_head_dim,
- # dropout=dropout,
- # bias=attention_bias,
- # cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- # upcast_attention=upcast_attention,
- # processor=JointAttnProcessor()
- # )
- # nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
- # self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
- # else:
- # self.attn_joint = None
- # self.norm_joint = None
-
- # 2. Cross-Attn
- if cross_attention_dim is not None or double_self_attention:
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
- # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
- # the second cross attention block.
- self.norm2 = (
- AdaLayerNorm(dim, num_embeds_ada_norm)
- if self.use_ada_layer_norm
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- )
- self.attn2 = Attention(
- query_dim=dim,
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- upcast_attention=upcast_attention,
- ) # is self-attn if encoder_hidden_states is none
- else:
- self.norm2 = None
- self.attn2 = None
-
- # 3. Feed-forward
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = 0
-
- self.num_views = num_views
-
-
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- self._chunk_dim = dim
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- cross_attention_kwargs: Dict[str, Any] = None,
- class_labels: Optional[torch.LongTensor] = None,
- dino_feature: Optional[torch.FloatTensor] = None
- ):
- assert attention_mask is None # not supported yet
- # Notice that normalization is always applied before the real computation in the following blocks.
- # 1. Self-Attention
- if self.use_ada_layer_norm:
- norm_hidden_states = self.norm1(hidden_states, timestep)
- elif self.use_ada_layer_norm_zero:
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- else:
- norm_hidden_states = self.norm1(hidden_states)
-
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
- attn_output = self.attn1(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
- attention_mask=attention_mask,
- # multiview_attention=self.multiview_attention,
- # mvcd_attention=self.mvcd_attention,
- **cross_attention_kwargs,
- )
-
-
- if self.use_ada_layer_norm_zero:
- attn_output = gate_msa.unsqueeze(1) * attn_output
- hidden_states = attn_output + hidden_states
-
- # import pdb;pdb.set_trace()
- # 1.1 row wise multiview attention
- if self.rowwise_attention:
- norm_hidden_states = (
- self.norm_mv(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_mv(hidden_states)
- )
- attn_output = self.attn_mv(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
- attention_mask=attention_mask,
- num_views=self.num_views,
- multiview_attention=self.multiview_attention,
- cd_attention_mid=self.cd_attention_mid,
- **cross_attention_kwargs,
- )
- hidden_states = attn_output + hidden_states
-
-
- # 2. Cross-Attention
- if self.attn2 is not None:
- norm_hidden_states = (
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
- )
-
- attn_output = self.attn2(
- norm_hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- **cross_attention_kwargs,
- )
- hidden_states = attn_output + hidden_states
-
- # 3. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
-
- if self.use_ada_layer_norm_zero:
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
- if self._chunk_size is not None:
- # "feed_forward_chunk_size" can be used to save memory
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
- raise ValueError(
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
- )
-
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
- ff_output = torch.cat(
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
- dim=self._chunk_dim,
- )
- else:
- ff_output = self.ff(norm_hidden_states)
-
- if self.use_ada_layer_norm_zero:
- ff_output = gate_mlp.unsqueeze(1) * ff_output
-
- hidden_states = ff_output + hidden_states
-
- return hidden_states
-
-
-class CustomAttention(Attention):
- def set_use_memory_efficient_attention_xformers(
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
- ):
- processor = XFormersMVAttnProcessor()
- self.set_processor(processor)
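- # Note: the boolean flag is ignored here; the xformers-based multi-view processor is
- # installed unconditionally whenever this method is called.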
- # print("using xformers attention processor")
-
-
-class CustomJointAttention(Attention):
- def set_use_memory_efficient_attention_xformers(
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
- ):
- processor = XFormersJointAttnProcessor()
- self.set_processor(processor)
- # print("using xformers attention processor")
-
-class MVAttnProcessor:
- r"""
- Row-wise multi-view attention processor: tokens that share the same image row attend to each other across all views.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_views=1,
- cd_attention_mid=False
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- height = int(math.sqrt(sequence_length))
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
- # e.g. query torch.Size([b*4, 1024, 320]), key torch.Size([b*4, 1024, 320]), value torch.Size([b*4, 1024, 320])
- # pdb.set_trace()
- # multi-view self-attention
- def transpose(tensor):
- tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
- tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
- tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
- return tensor
-
- if cd_attention_mid:
- key = transpose(key)
- value = transpose(value)
- query = transpose(query)
- else:
- key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
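- # Shape sketch (illustrative, e.g. b=2, num_views=4, 32x32 latents, c=320): q/k/v start as
- # (b*v, h*w, c) = (8, 1024, 320). The rearrange "(b v) (h w) c -> (b h) (v w) c" groups all
- # tokens that share the same image row across views, giving (b*h, v*w, c) = (64, 128, 320),
- # so attention mixes views only within one row rather than across the whole image. With
- # cd_attention_mid, `transpose` additionally concatenates the normal and color domains along
- # the width so both domains share this row-wise attention.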
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- if cd_attention_mid:
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
- hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
- hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
- hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
- else:
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class XFormersMVAttnProcessor:
- r"""
- xFormers (memory-efficient) variant of the row-wise multi-view attention processor.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_views=1,
- multiview_attention=True,
- cd_attention_mid=False
- ):
- # print(num_views)
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- height = int(math.sqrt(sequence_length))
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # from yuancheng; here attention_mask is None
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
-
- if attn.group_norm is not None:
- print('Warning: group norm is enabled; use it with care together with row-wise attention')
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key_raw = attn.to_k(encoder_hidden_states)
- value_raw = attn.to_v(encoder_hidden_states)
-
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
- # pdb.set_trace()
- def transpose(tensor):
- tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
- tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
- tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
- return tensor
- # print(mvcd_attention)
- # import pdb;pdb.set_trace()
- if cd_attention_mid:
- key = transpose(key_raw)
- value = transpose(value_raw)
- query = transpose(query)
- else:
- key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
-
-
- query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if cd_attention_mid:
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
- hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
- hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
- hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
- else:
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class XFormersJointAttnProcessor:
- r"""
- xFormers (memory-efficient) variant of the cross-domain joint attention processor.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_tasks=2
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- # from yuancheng; here attention_mask is None
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- assert num_tasks == 2 # only support two tasks now
-
- def transpose(tensor):
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
- tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
- return tensor
- key = transpose(key)
- value = transpose(value)
- query = transpose(query)
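- # Layout sketch: the batch is assumed to be [normal-domain, color-domain] stacked along dim 0,
- # each of size b*v. `transpose` splits that stack and concatenates the two halves along the
- # token dimension, so every sample now carries 2*h*w tokens and the normal and color branches
- # attend to each other; the chunk/cat after `to_out` undoes this and restores (2*b*v, h*w, c).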
- # from icecream import ic
- # ic(key.shape, value.shape, query.shape)
- # import pdb;pdb.set_trace()
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)
- hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class JointAttnProcessor:
- r"""
- Cross-domain joint attention processor: the normal and color branches are concatenated along the token dimension so the two domains attend to each other.
- """
-
- def __call__(
- self,
- attn: Attention,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- num_tasks=2
- ):
-
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- assert num_tasks == 2 # only support two tasks now
-
- def transpose(tensor):
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
- tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
- return tensor
- key = transpose(key)
- value = transpose(value)
- query = transpose(query)
-
-
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- # split the concatenated normal/color tokens and stack them back along the batch dim
- hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2) # bv hw c
- hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
\ No newline at end of file
diff --git a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_blocks.py b/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_blocks.py
deleted file mode 100644
index 14face576d731c4d17c864e6fd2b9d156a63a8cf..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_blocks.py
+++ /dev/null
@@ -1,970 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, Dict, Optional, Tuple
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from diffusers.utils import is_torch_version, logging
-from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
-from diffusers.models.dual_transformer_2d import DualTransformer2DModel
-from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
-
-from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
-from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-def get_down_block(
- down_block_type,
- num_layers,
- in_channels,
- out_channels,
- temb_channels,
- add_downsample,
- resnet_eps,
- resnet_act_fn,
- transformer_layers_per_block=1,
- num_attention_heads=None,
- resnet_groups=None,
- cross_attention_dim=None,
- downsample_padding=None,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- resnet_time_scale_shift="default",
- resnet_skip_time_act=False,
- resnet_out_scale_factor=1.0,
- cross_attention_norm=None,
- attention_head_dim=None,
- downsample_type=None,
- num_views=1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- sparse_mv_attention: bool = False,
- selfattn_block: str = "custom",
- mvcd_attention: bool=False,
- use_dino: bool = False
-):
- # If attn head dim is not defined, we default it to the number of heads
- if attention_head_dim is None:
- logger.warning(
- f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
- )
- attention_head_dim = num_attention_heads
-
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
- if down_block_type == "DownBlock2D":
- return DownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "ResnetDownsampleBlock2D":
- return ResnetDownsampleBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- skip_time_act=resnet_skip_time_act,
- output_scale_factor=resnet_out_scale_factor,
- )
- elif down_block_type == "AttnDownBlock2D":
- if add_downsample is False:
- downsample_type = None
- else:
- downsample_type = downsample_type or "conv" # default to 'conv'
- return AttnDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- downsample_type=downsample_type,
- )
- elif down_block_type == "CrossAttnDownBlock2D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
- return CrossAttnDownBlock2D(
- num_layers=num_layers,
- transformer_layers_per_block=transformer_layers_per_block,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- # custom MV2D attention block
- elif down_block_type == "CrossAttnDownBlockMV2D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
- return CrossAttnDownBlockMV2D(
- num_layers=num_layers,
- transformer_layers_per_block=transformer_layers_per_block,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- selfattn_block=selfattn_block,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- elif down_block_type == "SimpleCrossAttnDownBlock2D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
- return SimpleCrossAttnDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- skip_time_act=resnet_skip_time_act,
- output_scale_factor=resnet_out_scale_factor,
- only_cross_attention=only_cross_attention,
- cross_attention_norm=cross_attention_norm,
- )
- elif down_block_type == "SkipDownBlock2D":
- return SkipDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "AttnSkipDownBlock2D":
- return AttnSkipDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "DownEncoderBlock2D":
- return DownEncoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "AttnDownEncoderBlock2D":
- return AttnDownEncoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif down_block_type == "KDownBlock2D":
- return KDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- )
- elif down_block_type == "KCrossAttnDownBlock2D":
- return KCrossAttnDownBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- cross_attention_dim=cross_attention_dim,
- attention_head_dim=attention_head_dim,
- add_self_attention=not add_downsample,
- )
- raise ValueError(f"{down_block_type} does not exist.")
-
-
-def get_up_block(
- up_block_type,
- num_layers,
- in_channels,
- out_channels,
- prev_output_channel,
- temb_channels,
- add_upsample,
- resnet_eps,
- resnet_act_fn,
- transformer_layers_per_block=1,
- num_attention_heads=None,
- resnet_groups=None,
- cross_attention_dim=None,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- resnet_time_scale_shift="default",
- resnet_skip_time_act=False,
- resnet_out_scale_factor=1.0,
- cross_attention_norm=None,
- attention_head_dim=None,
- upsample_type=None,
- num_views=1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- sparse_mv_attention: bool = False,
- selfattn_block: str = "custom",
- mvcd_attention: bool=False,
- use_dino: bool = False
-):
- # If attn head dim is not defined, we default it to the number of heads
- if attention_head_dim is None:
- logger.warning(
- f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
- )
- attention_head_dim = num_attention_heads
-
- up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
- if up_block_type == "UpBlock2D":
- return UpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "ResnetUpsampleBlock2D":
- return ResnetUpsampleBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- skip_time_act=resnet_skip_time_act,
- output_scale_factor=resnet_out_scale_factor,
- )
- elif up_block_type == "CrossAttnUpBlock2D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
- return CrossAttnUpBlock2D(
- num_layers=num_layers,
- transformer_layers_per_block=transformer_layers_per_block,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- # custom MV2D attention block
- elif up_block_type == "CrossAttnUpBlockMV2D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
- return CrossAttnUpBlockMV2D(
- num_layers=num_layers,
- transformer_layers_per_block=transformer_layers_per_block,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- selfattn_block=selfattn_block,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- elif up_block_type == "SimpleCrossAttnUpBlock2D":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
- return SimpleCrossAttnUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- skip_time_act=resnet_skip_time_act,
- output_scale_factor=resnet_out_scale_factor,
- only_cross_attention=only_cross_attention,
- cross_attention_norm=cross_attention_norm,
- )
- elif up_block_type == "AttnUpBlock2D":
- if add_upsample is False:
- upsample_type = None
- else:
- upsample_type = upsample_type or "conv" # default to 'conv'
-
- return AttnUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- upsample_type=upsample_type,
- )
- elif up_block_type == "SkipUpBlock2D":
- return SkipUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "AttnSkipUpBlock2D":
- return AttnSkipUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- elif up_block_type == "UpDecoderBlock2D":
- return UpDecoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- temb_channels=temb_channels,
- )
- elif up_block_type == "AttnUpDecoderBlock2D":
- return AttnUpDecoderBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- attention_head_dim=attention_head_dim,
- resnet_time_scale_shift=resnet_time_scale_shift,
- temb_channels=temb_channels,
- )
- elif up_block_type == "KUpBlock2D":
- return KUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- )
- elif up_block_type == "KCrossAttnUpBlock2D":
- return KCrossAttnUpBlock2D(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- cross_attention_dim=cross_attention_dim,
- attention_head_dim=attention_head_dim,
- )
-
- raise ValueError(f"{up_block_type} does not exist.")
-
-
-class UNetMidBlockMV2DCrossAttn(nn.Module):
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- transformer_layers_per_block: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- num_attention_heads=1,
- output_scale_factor=1.0,
- cross_attention_dim=1280,
- dual_cross_attention=False,
- use_linear_projection=False,
- upcast_attention=False,
- num_views: int = 1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- sparse_mv_attention: bool = False,
- selfattn_block: str = "custom",
- mvcd_attention: bool=False,
- use_dino: bool = False
- ):
- super().__init__()
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
- if selfattn_block == "custom":
- from .transformer_mv2d import TransformerMV2DModel
- elif selfattn_block == "rowwise":
- from .transformer_mv2d_rowwise import TransformerMV2DModel
- elif selfattn_block == "self_rowwise":
- from .transformer_mv2d_self_rowwise import TransformerMV2DModel
- else:
- raise NotImplementedError
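- # selfattn_block selects which TransformerMV2DModel implementation is used: "custom"
- # (transformer_mv2d, presumably dense multi-view attention), "rowwise"
- # (transformer_mv2d_rowwise), or "self_rowwise" (transformer_mv2d_self_rowwise, per-view
- # self-attention plus a row-wise multi-view attention branch).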
-
- # there is always at least one resnet
- resnets = [
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- ]
- attentions = []
-
- for _ in range(num_layers):
- if not dual_cross_attention:
- attentions.append(
- TransformerMV2DModel(
- num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
- num_layers=transformer_layers_per_block,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- )
- else:
- raise NotImplementedError
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- dino_feature: Optional[torch.FloatTensor] = None
- ) -> torch.FloatTensor:
- hidden_states = self.resnets[0](hidden_states, temb)
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- dino_feature=dino_feature,
- return_dict=False,
- )[0]
- hidden_states = resnet(hidden_states, temb)
-
- return hidden_states
-
-
-class CrossAttnUpBlockMV2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- prev_output_channel: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- transformer_layers_per_block: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- num_attention_heads=1,
- cross_attention_dim=1280,
- output_scale_factor=1.0,
- add_upsample=True,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- num_views: int = 1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- sparse_mv_attention: bool = False,
- selfattn_block: str = "custom",
-        mvcd_attention: bool = False,
- use_dino: bool = False
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
-
- if selfattn_block == "custom":
- from .transformer_mv2d import TransformerMV2DModel
- elif selfattn_block == "rowwise":
- from .transformer_mv2d_rowwise import TransformerMV2DModel
- elif selfattn_block == "self_rowwise":
- from .transformer_mv2d_self_rowwise import TransformerMV2DModel
- else:
- raise NotImplementedError
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- if not dual_cross_attention:
- attentions.append(
- TransformerMV2DModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=transformer_layers_per_block,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- )
- else:
- raise NotImplementedError
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- upsample_size: Optional[int] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- dino_feature: Optional[torch.FloatTensor] = None
- ):
- for resnet, attn in zip(self.resnets, self.attentions):
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(attn, return_dict=False),
- hidden_states,
- encoder_hidden_states,
- dino_feature,
- None, # timestep
- None, # class_labels
- cross_attention_kwargs,
- attention_mask,
- encoder_attention_mask,
- **ckpt_kwargs,
- )[0]
- else:
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- dino_feature=dino_feature,
- return_dict=False,
- )[0]
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size)
-
- return hidden_states
-
-
-class CrossAttnDownBlockMV2D(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- transformer_layers_per_block: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- num_attention_heads=1,
- cross_attention_dim=1280,
- output_scale_factor=1.0,
- downsample_padding=1,
- add_downsample=True,
- dual_cross_attention=False,
- use_linear_projection=False,
- only_cross_attention=False,
- upcast_attention=False,
- num_views: int = 1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- sparse_mv_attention: bool = False,
- selfattn_block: str = "custom",
-        mvcd_attention: bool = False,
- use_dino: bool = False
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
- if selfattn_block == "custom":
- from .transformer_mv2d import TransformerMV2DModel
- elif selfattn_block == "rowwise":
- from .transformer_mv2d_rowwise import TransformerMV2DModel
- elif selfattn_block == "self_rowwise":
- from .transformer_mv2d_self_rowwise import TransformerMV2DModel
- else:
- raise NotImplementedError
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- if not dual_cross_attention:
- attentions.append(
- TransformerMV2DModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=transformer_layers_per_block,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- )
- else:
- raise NotImplementedError
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- dino_feature: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- additional_residuals=None,
- ):
- output_states = ()
-
- blocks = list(zip(self.resnets, self.attentions))
-
- for i, (resnet, attn) in enumerate(blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(attn, return_dict=False),
- hidden_states,
- encoder_hidden_states,
- dino_feature,
- None, # timestep
- None, # class_labels
- cross_attention_kwargs,
- attention_mask,
- encoder_attention_mask,
- **ckpt_kwargs,
- )[0]
- else:
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- dino_feature=dino_feature,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
-
- # apply additional residuals to the output of the last pair of resnet and attention blocks
- if i == len(blocks) - 1 and additional_residuals is not None:
- hidden_states = hidden_states + additional_residuals
-
- output_states = output_states + (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- output_states = output_states + (hidden_states,)
-
- return hidden_states, output_states
-
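The MV2D blocks above wrap each resnet/attention pair in torch.utils.checkpoint when gradient_checkpointing is enabled, disabling re-entrant checkpointing on torch >= 1.11. A minimal standalone sketch of that pattern follows; the wrapped module and tensor shapes are illustrative assumptions, not taken from the deleted files.

import torch
import torch.nn as nn
import torch.utils.checkpoint

def create_custom_forward(module, return_dict=None):
    # Wrap the module so positional args can be forwarded through checkpoint().
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)
    return custom_forward

# Illustrative stand-in for a ResnetBlock2D; channel and spatial sizes are assumptions.
resnet = nn.Sequential(nn.Conv2d(4, 4, kernel_size=3, padding=1), nn.SiLU())
hidden_states = torch.randn(2, 4, 16, 16, requires_grad=True)

# use_reentrant=False is only passed on torch >= 1.11, as guarded in the blocks above.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet), hidden_states, use_reentrant=False
)
hidden_states.sum().backward()  # gradients flow even though activations are recomputed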
diff --git a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_condition.py b/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_condition.py
deleted file mode 100644
index f49b01322b7b7dae9a02258943730b4346fc1792..0000000000000000000000000000000000000000
--- a/apps/third_party/Era3D/mvdiffusion/models/unet_mv2d_condition.py
+++ /dev/null
@@ -1,1686 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-import os
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders import UNet2DConditionLoadersMixin
-from diffusers.utils import BaseOutput, logging
-from diffusers.models.activations import get_activation
-from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
-from diffusers.models.embeddings import (
- GaussianFourierProjection,
- ImageHintTimeEmbedding,
- ImageProjection,
- ImageTimeEmbedding,
- TextImageProjection,
- TextImageTimeEmbedding,
- TextTimeEmbedding,
- TimestepEmbedding,
- Timesteps,
-)
-from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
-from diffusers.models.unet_2d_blocks import (
- CrossAttnDownBlock2D,
- CrossAttnUpBlock2D,
- DownBlock2D,
- UNetMidBlock2DCrossAttn,
- UNetMidBlock2DSimpleCrossAttn,
- UpBlock2D,
-)
-from diffusers.utils import (
- CONFIG_NAME,
- FLAX_WEIGHTS_NAME,
- SAFETENSORS_WEIGHTS_NAME,
- WEIGHTS_NAME,
- _add_variant,
- _get_model_file,
- deprecate,
- is_torch_version,
- logging,
-)
-from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.utils.hub_utils import HF_HUB_OFFLINE
-from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE
-
-from diffusers import __version__
-from .unet_mv2d_blocks import (
- CrossAttnDownBlockMV2D,
- CrossAttnUpBlockMV2D,
- UNetMidBlockMV2DCrossAttn,
- get_down_block,
- get_up_block,
-)
-from einops import rearrange, repeat
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-@dataclass
-class UNetMV2DConditionOutput(BaseOutput):
- """
-    The output of [`UNetMV2DConditionModel`].
-
- Args:
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
- """
-
- sample: torch.FloatTensor = None
-
-
-class ResidualBlock(nn.Module):
- def __init__(self, dim):
- super(ResidualBlock, self).__init__()
- self.linear1 = nn.Linear(dim, dim)
- self.activation = nn.SiLU()
- self.linear2 = nn.Linear(dim, dim)
-
- def forward(self, x):
- identity = x
- out = self.linear1(x)
- out = self.activation(out)
- out = self.linear2(out)
- out += identity
- out = self.activation(out)
- return out
-
-class ResidualLiner(nn.Module):
- def __init__(self, in_features, out_features, dim, act=None, num_block=1):
- super(ResidualLiner, self).__init__()
- self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU())
-
- blocks = nn.ModuleList()
- for _ in range(num_block):
- blocks.append(ResidualBlock(dim))
- self.blocks = blocks
-
- self.linear_out = nn.Linear(dim, out_features)
- self.act = act
-
- def forward(self, x):
- out = self.linear_in(x)
- for block in self.blocks:
- out = block(out)
- out = self.linear_out(out)
- if self.act is not None:
- out = self.act(out)
- return out
-
-class BasicConvBlock(nn.Module):
- def __init__(self, in_channels, out_channels, stride=1):
- super(BasicConvBlock, self).__init__()
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
- self.act = nn.SiLU()
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
-        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
- self.downsample = nn.Sequential()
- if stride != 1 or in_channels != out_channels:
- self.downsample = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
-                nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
- )
-
- def forward(self, x):
- identity = x
- out = self.conv1(x)
- out = self.norm1(out)
- out = self.act(out)
- out = self.conv2(out)
- out = self.norm2(out)
- out += self.downsample(identity)
- out = self.act(out)
- return out
-
-class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
- r"""
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
- shaped output.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
- for all models (such as downloading or saving).
-
- Parameters:
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
- Height and width of input/output sample.
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
- Whether to flip the sin to cos in the time embedding.
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
- The tuple of downsample blocks to use.
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
- Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
- The tuple of upsample blocks to use.
- only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
- Whether to include self-attention in the basic transformer blocks, see
- [`~models.attention.BasicTransformerBlock`].
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
- The tuple of output channels for each block.
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
- If `None`, normalization and activation layers is skipped in post-processing.
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
- The dimension of the cross attention features.
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
- encoder_hid_dim (`int`, *optional*, defaults to None):
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
- dimension to `cross_attention_dim`.
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
-            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
- num_attention_heads (`int`, *optional*):
- The number of attention heads. If not defined, defaults to `attention_head_dim`
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
- class_embed_type (`str`, *optional*, defaults to `None`):
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
- addition_embed_type (`str`, *optional*, defaults to `None`):
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
- "text". "text" will use the `TextTimeEmbedding` layer.
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
- Dimension for the timestep embeddings.
- num_class_embeds (`int`, *optional*, defaults to `None`):
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
- class conditioning with `class_embed_type` equal to `None`.
- time_embedding_type (`str`, *optional*, defaults to `positional`):
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
- time_embedding_dim (`int`, *optional*, defaults to `None`):
- An optional override for the dimension of the projected time embedding.
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
- timestep_post_act (`str`, *optional*, defaults to `None`):
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
- The dimension of `cond_proj` layer in the timestep embedding.
- conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
- conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
- embeddings with the class embeddings.
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
- otherwise.
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- sample_size: Optional[int] = None,
- in_channels: int = 4,
- out_channels: int = 4,
- center_input_sample: bool = False,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- down_block_types: Tuple[str] = (
- "CrossAttnDownBlockMV2D",
- "CrossAttnDownBlockMV2D",
- "CrossAttnDownBlockMV2D",
- "DownBlock2D",
- ),
- mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
- only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
- layers_per_block: Union[int, Tuple[int]] = 2,
- downsample_padding: int = 1,
- mid_block_scale_factor: float = 1,
- act_fn: str = "silu",
- norm_num_groups: Optional[int] = 32,
- norm_eps: float = 1e-5,
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
- encoder_hid_dim: Optional[int] = None,
- encoder_hid_dim_type: Optional[str] = None,
- attention_head_dim: Union[int, Tuple[int]] = 8,
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- class_embed_type: Optional[str] = None,
- addition_embed_type: Optional[str] = None,
- addition_time_embed_dim: Optional[int] = None,
- num_class_embeds: Optional[int] = None,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- resnet_skip_time_act: bool = False,
- resnet_out_scale_factor: int = 1.0,
- time_embedding_type: str = "positional",
- time_embedding_dim: Optional[int] = None,
- time_embedding_act_fn: Optional[str] = None,
- timestep_post_act: Optional[str] = None,
- time_cond_proj_dim: Optional[int] = None,
- conv_in_kernel: int = 3,
- conv_out_kernel: int = 3,
- projection_class_embeddings_input_dim: Optional[int] = None,
- projection_camera_embeddings_input_dim: Optional[int] = None,
- class_embeddings_concat: bool = False,
- mid_block_only_cross_attention: Optional[bool] = None,
- cross_attention_norm: Optional[str] = None,
- addition_embed_type_num_heads=64,
- num_views: int = 1,
- cd_attention_last: bool = False,
- cd_attention_mid: bool = False,
- multiview_attention: bool = True,
- sparse_mv_attention: bool = False,
- selfattn_block: str = "custom",
- mvcd_attention: bool = False,
- regress_elevation: bool = False,
- regress_focal_length: bool = False,
- num_regress_blocks: int = 4,
- use_dino: bool = False,
- addition_downsample: bool = False,
- addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280),
- ):
- super().__init__()
-
- self.sample_size = sample_size
- self.num_views = num_views
- self.mvcd_attention = mvcd_attention
- if num_attention_heads is not None:
- raise ValueError(
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
- )
-
- # If `num_attention_heads` is not defined (which is the case for most models)
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
- # which is why we correct for the naming here.
- num_attention_heads = num_attention_heads or attention_head_dim
-
- # Check inputs
- if len(down_block_types) != len(up_block_types):
- raise ValueError(
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
- )
-
- if len(block_out_channels) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
- )
-
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
- )
-
- # input
- conv_in_padding = (conv_in_kernel - 1) // 2
- self.conv_in = nn.Conv2d(
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
- )
-
- # time
- if time_embedding_type == "fourier":
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
- if time_embed_dim % 2 != 0:
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
- self.time_proj = GaussianFourierProjection(
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
- )
- timestep_input_dim = time_embed_dim
- elif time_embedding_type == "positional":
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
-
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
- timestep_input_dim = block_out_channels[0]
- else:
- raise ValueError(
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
- )
-
- self.time_embedding = TimestepEmbedding(
- timestep_input_dim,
- time_embed_dim,
- act_fn=act_fn,
- post_act_fn=timestep_post_act,
- cond_proj_dim=time_cond_proj_dim,
- )
-
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
- encoder_hid_dim_type = "text_proj"
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
-
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
- raise ValueError(
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
- )
-
- if encoder_hid_dim_type == "text_proj":
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
- elif encoder_hid_dim_type == "text_image_proj":
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
- self.encoder_hid_proj = TextImageProjection(
- text_embed_dim=encoder_hid_dim,
- image_embed_dim=cross_attention_dim,
- cross_attention_dim=cross_attention_dim,
- )
- elif encoder_hid_dim_type == "image_proj":
- # Kandinsky 2.2
- self.encoder_hid_proj = ImageProjection(
- image_embed_dim=encoder_hid_dim,
- cross_attention_dim=cross_attention_dim,
- )
- elif encoder_hid_dim_type is not None:
- raise ValueError(
-                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
- )
- else:
- self.encoder_hid_proj = None
-
- # class embedding
- if class_embed_type is None and num_class_embeds is not None:
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
- elif class_embed_type == "timestep":
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
- elif class_embed_type == "identity":
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
- elif class_embed_type == "projection":
- if projection_class_embeddings_input_dim is None:
- raise ValueError(
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
- )
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
- # 2. it projects from an arbitrary input dimension.
- #
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
- elif class_embed_type == "simple_projection":
- if projection_class_embeddings_input_dim is None:
- raise ValueError(
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
- )
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
- else:
- self.class_embedding = None
-
- if addition_embed_type == "text":
- if encoder_hid_dim is not None:
- text_time_embedding_from_dim = encoder_hid_dim
- else:
- text_time_embedding_from_dim = cross_attention_dim
-
- self.add_embedding = TextTimeEmbedding(
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
- )
- elif addition_embed_type == "text_image":
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
- self.add_embedding = TextImageTimeEmbedding(
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
- )
- elif addition_embed_type == "text_time":
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
- elif addition_embed_type == "image":
- # Kandinsky 2.2
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
- elif addition_embed_type == "image_hint":
- # Kandinsky 2.2 ControlNet
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
- elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
-
- if time_embedding_act_fn is None:
- self.time_embed_act = None
- else:
- self.time_embed_act = get_activation(time_embedding_act_fn)
-
- self.down_blocks = nn.ModuleList([])
- self.up_blocks = nn.ModuleList([])
-
- if isinstance(only_cross_attention, bool):
- if mid_block_only_cross_attention is None:
- mid_block_only_cross_attention = only_cross_attention
-
- only_cross_attention = [only_cross_attention] * len(down_block_types)
-
- if mid_block_only_cross_attention is None:
- mid_block_only_cross_attention = False
-
- if isinstance(num_attention_heads, int):
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
-
- if isinstance(attention_head_dim, int):
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
- if isinstance(cross_attention_dim, int):
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
-
- if isinstance(layers_per_block, int):
- layers_per_block = [layers_per_block] * len(down_block_types)
-
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
-
- if class_embeddings_concat:
- # The time embeddings are concatenated with the class embeddings. The dimension of the
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
- # regular time embeddings
- blocks_time_embed_dim = time_embed_dim * 2
- else:
- blocks_time_embed_dim = time_embed_dim
-
- # down
- output_channel = block_out_channels[0]
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
-
- down_block = get_down_block(
- down_block_type,
- num_layers=layers_per_block[i],
- transformer_layers_per_block=transformer_layers_per_block[i],
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=blocks_time_embed_dim,
- add_downsample=not is_final_block,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim[i],
- num_attention_heads=num_attention_heads[i],
- downsample_padding=downsample_padding,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_skip_time_act=resnet_skip_time_act,
- resnet_out_scale_factor=resnet_out_scale_factor,
- cross_attention_norm=cross_attention_norm,
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- selfattn_block=selfattn_block,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- self.down_blocks.append(down_block)
-
- # mid
- if mid_block_type == "UNetMidBlock2DCrossAttn":
- self.mid_block = UNetMidBlock2DCrossAttn(
- transformer_layers_per_block=transformer_layers_per_block[-1],
- in_channels=block_out_channels[-1],
- temb_channels=blocks_time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- cross_attention_dim=cross_attention_dim[-1],
- num_attention_heads=num_attention_heads[-1],
- resnet_groups=norm_num_groups,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- )
- # custom MV2D attention block
- elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
- self.mid_block = UNetMidBlockMV2DCrossAttn(
- transformer_layers_per_block=transformer_layers_per_block[-1],
- in_channels=block_out_channels[-1],
- temb_channels=blocks_time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- cross_attention_dim=cross_attention_dim[-1],
- num_attention_heads=num_attention_heads[-1],
- resnet_groups=norm_num_groups,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- selfattn_block=selfattn_block,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
- in_channels=block_out_channels[-1],
- temb_channels=blocks_time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- cross_attention_dim=cross_attention_dim[-1],
- attention_head_dim=attention_head_dim[-1],
- resnet_groups=norm_num_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- skip_time_act=resnet_skip_time_act,
- only_cross_attention=mid_block_only_cross_attention,
- cross_attention_norm=cross_attention_norm,
- )
- elif mid_block_type is None:
- self.mid_block = None
- else:
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
-
- self.addition_downsample = addition_downsample
- if self.addition_downsample:
- inc = block_out_channels[-1]
- self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self.conv_block = nn.ModuleList()
- self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1))
- for dim_ in addition_channels[1:-1]:
- self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1))
-            self.conv_block.append(BasicConvBlock(addition_channels[-2], inc))
- self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False)
- nn.init.zeros_(self.addition_conv_out.weight.data)
- self.addition_act_out = nn.SiLU()
- self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
-
- self.regress_elevation = regress_elevation
- self.regress_focal_length = regress_focal_length
- if regress_elevation or regress_focal_length:
- self.pool = nn.AdaptiveAvgPool2d((1, 1))
- self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
-
-            regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels[-1]
-
- if regress_elevation:
- self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
- if regress_focal_length:
- self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
- '''
- self.regress_elevation = regress_elevation
- self.regress_focal_length = regress_focal_length
- if regress_elevation and (not regress_focal_length):
- print("Regressing elevation")
- cam_dim = 1
- elif regress_focal_length and (not regress_elevation):
- print("Regressing focal length")
- cam_dim = 6
- elif regress_elevation and regress_focal_length:
- print("Regressing both elevation and focal length")
- cam_dim = 7
- else:
- cam_dim = 0
- assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim"
- if regress_elevation or regress_focal_length:
- self.elevation_regressor = nn.ModuleList([
- nn.Linear(block_out_channels[-1], 1280),
- nn.SiLU(),
- nn.Linear(1280, 1280),
- nn.SiLU(),
- nn.Linear(1280, cam_dim)
- ])
- self.pool = nn.AdaptiveAvgPool2d((1, 1))
- self.focal_act = nn.Softmax(dim=-1)
- self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
- '''
-
- # count how many layers upsample the images
- self.num_upsamplers = 0
-
- # up
- reversed_block_out_channels = list(reversed(block_out_channels))
- reversed_num_attention_heads = list(reversed(num_attention_heads))
- reversed_layers_per_block = list(reversed(layers_per_block))
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
- reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
- only_cross_attention = list(reversed(only_cross_attention))
-
- output_channel = reversed_block_out_channels[0]
- for i, up_block_type in enumerate(up_block_types):
- is_final_block = i == len(block_out_channels) - 1
-
- prev_output_channel = output_channel
- output_channel = reversed_block_out_channels[i]
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
-
- # add upsample block for all BUT final layer
- if not is_final_block:
- add_upsample = True
- self.num_upsamplers += 1
- else:
- add_upsample = False
-
- up_block = get_up_block(
- up_block_type,
- num_layers=reversed_layers_per_block[i] + 1,
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
- in_channels=input_channel,
- out_channels=output_channel,
- prev_output_channel=prev_output_channel,
- temb_channels=blocks_time_embed_dim,
- add_upsample=add_upsample,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=reversed_cross_attention_dim[i],
- num_attention_heads=reversed_num_attention_heads[i],
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_skip_time_act=resnet_skip_time_act,
- resnet_out_scale_factor=resnet_out_scale_factor,
- cross_attention_norm=cross_attention_norm,
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
- num_views=num_views,
- cd_attention_last=cd_attention_last,
- cd_attention_mid=cd_attention_mid,
- multiview_attention=multiview_attention,
- sparse_mv_attention=sparse_mv_attention,
- selfattn_block=selfattn_block,
- mvcd_attention=mvcd_attention,
- use_dino=use_dino
- )
- self.up_blocks.append(up_block)
- prev_output_channel = output_channel
-
- # out
- if norm_num_groups is not None:
- self.conv_norm_out = nn.GroupNorm(
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
- )
-
- self.conv_act = get_activation(act_fn)
-
- else:
- self.conv_norm_out = None
- self.conv_act = None
-
- conv_out_padding = (conv_out_kernel - 1) // 2
- self.conv_out = nn.Conv2d(
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
- )
-
- @property
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model,
-            indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "set_processor"):
- processors[f"{name}.processor"] = module.processor
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- self.set_attn_processor(AttnProcessor())
-
- def set_attention_slice(self, slice_size):
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- sliceable_head_dims = []
-
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
- if hasattr(module, "set_attention_slice"):
- sliceable_head_dims.append(module.sliceable_head_dim)
-
- for child in module.children():
- fn_recursive_retrieve_sliceable_dims(child)
-
- # retrieve number of attention layers
- for module in self.children():
- fn_recursive_retrieve_sliceable_dims(module)
-
- num_sliceable_layers = len(sliceable_head_dims)
-
- if slice_size == "auto":
- # half the attention head size is usually a good trade-off between
- # speed and memory
- slice_size = [dim // 2 for dim in sliceable_head_dims]
- elif slice_size == "max":
- # make smallest slice possible
- slice_size = num_sliceable_layers * [1]
-
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
-
- if len(slice_size) != len(sliceable_head_dims):
- raise ValueError(
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
- )
-
- for i in range(len(slice_size)):
- size = slice_size[i]
- dim = sliceable_head_dims[i]
- if size is not None and size > dim:
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
-
- # Recursively walk through all the children.
- # Any children which exposes the set_attention_slice method
- # gets the message
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
- if hasattr(module, "set_attention_slice"):
- module.set_attention_slice(slice_size.pop())
-
- for child in module.children():
- fn_recursive_set_attention_slice(child, slice_size)
-
- reversed_slice_size = list(reversed(slice_size))
- for module in self.children():
- fn_recursive_set_attention_slice(module, reversed_slice_size)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
- module.gradient_checkpointing = value
-
- def forward(
- self,
- sample: torch.FloatTensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- class_labels: Optional[torch.Tensor] = None,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
- mid_block_additional_residual: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- dino_feature: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- vis_max_min: bool = False,
- ) -> Union[UNetMV2DConditionOutput, Tuple]:
- r"""
- The [`UNet2DConditionModel`] forward method.
-
- Args:
- sample (`torch.FloatTensor`):
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.FloatTensor`):
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
- encoder_attention_mask (`torch.Tensor`):
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
- which adds large negative values to the attention scores corresponding to "discard" tokens.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
- tuple.
- cross_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
- added_cond_kwargs: (`dict`, *optional*):
-                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
- are passed along to the UNet blocks.
-
- Returns:
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
- a `tuple` is returned where the first element is the sample tensor.
- """
- record_max_min = {}
-        # By default samples have to be at least a multiple of the overall upsampling factor.
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
- # on the fly if necessary.
- default_overall_up_factor = 2**self.num_upsamplers
-
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
- forward_upsample_size = False
- upsample_size = None
-
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
- logger.info("Forward upsample size to force interpolation output size.")
- forward_upsample_size = True
-
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
- # expects mask of shape:
- # [batch, key_tokens]
- # adds singleton query_tokens dimension:
- # [batch, 1, key_tokens]
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
- if attention_mask is not None:
- # assume that mask is expressed as:
- # (1 = keep, 0 = discard)
- # convert mask into a bias that can be added to attention scores:
- # (keep = +0, discard = -10000.0)
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
- if encoder_attention_mask is not None:
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
- # 0. center input if necessary
- if self.config.center_input_sample:
- sample = 2 * sample - 1.0
- # 1. time
- timesteps = timestep
- if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
- else:
- dtype = torch.int32 if is_mps else torch.int64
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = self.time_proj(timesteps)
-
- # `Timesteps` does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=sample.dtype)
-
- emb = self.time_embedding(t_emb, timestep_cond)
- aug_emb = None
- if self.class_embedding is not None:
- if class_labels is None:
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
-
- if self.config.class_embed_type == "timestep":
- class_labels = self.time_proj(class_labels)
-
- # `Timesteps` does not contain any weights and will always return f32 tensors
- # there might be better ways to encapsulate this.
- class_labels = class_labels.to(dtype=sample.dtype)
-
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
- if self.config.class_embeddings_concat:
- emb = torch.cat([emb, class_emb], dim=-1)
- else:
- emb = emb + class_emb
-
- if self.config.addition_embed_type == "text":
- aug_emb = self.add_embedding(encoder_hidden_states)
- elif self.config.addition_embed_type == "text_image":
- # Kandinsky 2.1 - style
- if "image_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
- )
-
- image_embs = added_cond_kwargs.get("image_embeds")
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
- aug_emb = self.add_embedding(text_embs, image_embs)
- elif self.config.addition_embed_type == "text_time":
- # SDXL - style
- if "text_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
- )
- text_embeds = added_cond_kwargs.get("text_embeds")
- if "time_ids" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
- )
- time_ids = added_cond_kwargs.get("time_ids")
- time_embeds = self.add_time_proj(time_ids.flatten())
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
-
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
- add_embeds = add_embeds.to(emb.dtype)
- aug_emb = self.add_embedding(add_embeds)
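- # each sample's time_ids (in SDXL: original size, crop coords, target size) are
- # sinusoidally embedded by add_time_proj, flattened to (B, num_ids * proj_dim) and
- # concatenated with the pooled text embeds before the shared add_embedding MLP.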
- elif self.config.addition_embed_type == "image":
- # Kandinsky 2.2 - style
- if "image_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
- )
- image_embs = added_cond_kwargs.get("image_embeds")
- aug_emb = self.add_embedding(image_embs)
- elif self.config.addition_embed_type == "image_hint":
- # Kandinsky 2.2 - style
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
- )
- image_embs = added_cond_kwargs.get("image_embeds")
- hint = added_cond_kwargs.get("hint")
- aug_emb, hint = self.add_embedding(image_embs, hint)
- sample = torch.cat([sample, hint], dim=1)
-
- emb = emb + aug_emb if aug_emb is not None else emb
- emb_pre_act = emb
- if self.time_embed_act is not None:
- emb = self.time_embed_act(emb)
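- # emb_pre_act keeps the pre-activation time embedding so the predicted camera/pose
- # embedding can be added to it in step 4.1 before the activation is re-applied.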
-
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
- # Kandinsky 2.1 - style
- if "image_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
- )
-
- image_embeds = added_cond_kwargs.get("image_embeds")
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
- # Kandinsky 2.2 - style
- if "image_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
- )
- image_embeds = added_cond_kwargs.get("image_embeds")
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
- # 2. pre-process
- sample = self.conv_in(sample)
- # 3. down
-
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
- is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
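- # ControlNet provides both mid- and down-block residuals (merged into the skip
- # connections after the down pass); a T2I-Adapter provides only down-block
- # residuals, which are consumed inside the loop below.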
-
- down_block_res_samples = (sample,)
- for i, downsample_block in enumerate(self.down_blocks):
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
- # For t2i-adapter CrossAttnDownBlock2D
- additional_residuals = {}
- if is_adapter and len(down_block_additional_residuals) > 0:
- additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
-
- sample, res_samples = downsample_block(
- hidden_states=sample,
- temb=emb,
- encoder_hidden_states=encoder_hidden_states,
- dino_feature=dino_feature,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- encoder_attention_mask=encoder_attention_mask,
- **additional_residuals,
- )
- else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
-
- if is_adapter and len(down_block_additional_residuals) > 0:
- sample += down_block_additional_residuals.pop(0)
-
- down_block_res_samples += res_samples
-
- if is_controlnet:
- new_down_block_res_samples = ()
-
- for down_block_res_sample, down_block_additional_residual in zip(
- down_block_res_samples, down_block_additional_residuals
- ):
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
-
- down_block_res_samples = new_down_block_res_samples
-
- if self.addition_downsample:
- global_sample = sample
- global_sample = self.downsample(global_sample)
- for layer in self.conv_block:
- global_sample = layer(global_sample)
- global_sample = self.addition_act_out(self.addition_conv_out(global_sample))
- global_sample = self.upsample(global_sample)
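- # optional global branch: downsample the current features, run a small conv stack,
- # upsample back, and add the result to `sample` after the mid block.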
- # 4. mid
- if self.mid_block is not None:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- dino_feature=dino_feature,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- encoder_attention_mask=encoder_attention_mask,
- )
- # 4.1 regress elevation and focal length
- # predict elevation -> embed -> projection -> add to time emb
- if self.regress_elevation or self.regress_focal_length:
- pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C)
- if self.mvcd_attention:
- pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0)
- pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C)
- pose_pred = []
- if self.regress_elevation:
- ele_pred = self.elevation_regressor(pool_embeds)
- ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views)
- ele_pred = torch.mean(ele_pred, dim=1)
- pose_pred.append(ele_pred) # b, c
-
- if self.regress_focal_length:
- focal_pred = self.focal_regressor(pool_embeds)
- focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views)
- focal_pred = torch.mean(focal_pred, dim=1)
- pose_pred.append(focal_pred)
- pose_pred = torch.cat(pose_pred, dim=-1)
- # 'e_de_da_sincos', (B, 2)
- pose_embeds = torch.cat([
- torch.sin(pose_pred),
- torch.cos(pose_pred)
- ], dim=-1)
- pose_embeds = self.camera_embedding(pose_embeds)
- pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0)
- if self.mvcd_attention:
- pose_embeds = torch.cat([pose_embeds,] * 2, dim=0)
-
- emb = pose_embeds + emb_pre_act
- if self.time_embed_act is not None:
- emb = self.time_embed_act(emb)
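- # 4.1 summary: pooled mid-block features (normal/color halves concatenated when
- # mvcd_attention is on) are regressed to elevation / focal values, averaged over
- # the views, encoded as (sin, cos), passed through camera_embedding, repeated per
- # view, and added to the pre-activation time embedding saved earlier.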
-
- if is_controlnet:
- sample = sample + mid_block_additional_residual
-
- if self.addition_downsample:
- sample = sample + global_sample
-
- # 5. up
- for i, upsample_block in enumerate(self.up_blocks):
- is_final_block = i == len(self.up_blocks) - 1
-
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
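- # each up block consumes as many skip connections as it has resnets, popped from
- # the tail of down_block_res_samples (deepest features are used first).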
-
- # if we have not reached the final block and need to forward the
- # upsample size, we do it here
- if not is_final_block and forward_upsample_size:
- upsample_size = down_block_res_samples[-1].shape[2:]
-
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
- sample = upsample_block(
- hidden_states=sample,
- temb=emb,
- res_hidden_states_tuple=res_samples,
- encoder_hidden_states=encoder_hidden_states,
- dino_feature=dino_feature,
- cross_attention_kwargs=cross_attention_kwargs,
- upsample_size=upsample_size,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- )
- else:
- sample = upsample_block(
- hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
- )
- if torch.isnan(sample).any() or torch.isinf(sample).any():
- raise RuntimeError("NaN or Inf detected in sample; stopping training.")
- # 6. post-process
- if self.conv_norm_out:
- sample = self.conv_norm_out(sample)
- sample = self.conv_act(sample)
- sample = self.conv_out(sample)
- if not return_dict:
- return (sample, pose_pred) if (self.regress_elevation or self.regress_focal_length) else (sample,)
- if self.regress_elevation or self.regress_focal_length:
- return UNetMV2DConditionOutput(sample=sample), pose_pred
- else:
- return UNetMV2DConditionOutput(sample=sample)
-
-
- @classmethod
- def from_pretrained_2d(
- cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
- camera_embedding_type: str, num_views: int, sample_size: int,
- zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
- projection_camera_embeddings_input_dim: int=2,
- cd_attention_last: bool = False, num_regress_blocks: int = 4,
- cd_attention_mid: bool = False, multiview_attention: bool = True,
- sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False,
- in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False,
- init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False,
- **kwargs
- ):
- r"""
- Instantiate a pretrained PyTorch model from a pretrained model configuration.
-
- The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
- train the model, set it back in training mode with `model.train()`.
-
- Parameters:
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`~ModelMixin.save_pretrained`].
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
- dtype is automatically derived from the model's weights.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
- incompletely downloaded files are deleted.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`):
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- use_auth_token (`str` or `bool`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- from_flax (`bool`, *optional*, defaults to `False`):
- Load the model weights from a Flax checkpoint save file.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- mirror (`str`, *optional*):
- Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
- information.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
- A map that specifies where each submodule should go. It doesn't need to be defined for each
- parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
- same device.
-
- Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
- more information about each option see [designing a device
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
- max_memory (`Dict`, *optional*):
- A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for
- each GPU and the available CPU RAM if unset.
- offload_folder (`str` or `os.PathLike`, *optional*):
- The path to offload weights if `device_map` contains the value `"disk"`.
- offload_state_dict (`bool`, *optional*):
- If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
- the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
- when there is some disk offload.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading by only loading the pretrained weights and not initializing them. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
- variant (`str`, *optional*):
- Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
- loading `from_flax`.
- use_safetensors (`bool`, *optional*, defaults to `None`):
- If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
- `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
- weights. If set to `False`, `safetensors` weights are not loaded.
-
-