diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..4720b1a53b512a49e81661f68abecdc267a7bae7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,37 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +assets/teaser.png filter=lfs diff=lfs merge=lfs -text +demos/example_003.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 4bb7a445b43a31999cb8d125c044a86cfdf08fd4..251cf678ebe65c33fbb110fb78495e15c99fb1bd 100644 
--- a/README.md +++ b/README.md @@ -1,14 +1,11 @@ --- -title: InstantIR -emoji: 🖼 +title: Hunyuan3D-1.0 +emoji: 😻 colorFrom: purple colorTo: red sdk: gradio sdk_version: 4.42.0 app_file: app.py pinned: false -license: apache-2.0 -short_description: diffusion-based Image Restoration model ---- - -Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference \ No newline at end of file +short_description: Text-to-3D and Image-to-3D Generation +--- \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc5dde13399a86d0ad1438afb50515f76eb550c --- /dev/null +++ b/app.py @@ -0,0 +1,284 @@ +import os +import warnings +from huggingface_hub import hf_hub_download +import gradio as gr +from glob import glob +import shutil +import torch +import numpy as np +from PIL import Image +from einops import rearrange +import argparse + +# Suppress warnings +warnings.simplefilter('ignore', category=UserWarning) +warnings.simplefilter('ignore', category=FutureWarning) +warnings.simplefilter('ignore', category=DeprecationWarning) + +def download_models(): + # Create weights directory if it doesn't exist + os.makedirs("weights", exist_ok=True) + os.makedirs("weights/hunyuanDiT", exist_ok=True) + + # Download Hunyuan3D-1 model + try: + hf_hub_download( + repo_id="tencent/Hunyuan3D-1", + local_dir="./weights", + resume_download=True + ) + print("Successfully downloaded Hunyuan3D-1 model") + except Exception as e: + print(f"Error downloading Hunyuan3D-1: {e}") + + # Download HunyuanDiT model + try: + hf_hub_download( + repo_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled", + local_dir="./weights/hunyuanDiT", + resume_download=True + ) + print("Successfully downloaded HunyuanDiT model") + except Exception as e: + print(f"Error downloading HunyuanDiT: {e}") + +# Download models before starting the app +download_models() + +# Parse arguments +parser = argparse.ArgumentParser() +parser.add_argument("--use_lite", default=False, action="store_true") +parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str) +parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str) +parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str) +parser.add_argument("--save_memory", default=False, action="store_true") +parser.add_argument("--device", default="cuda:0", type=str) +args = parser.parse_args() + +# Constants +CONST_PORT = 8080 +CONST_MAX_QUEUE = 1 +CONST_SERVER = '0.0.0.0' + +CONST_HEADER = ''' +

+Official 🤗 Gradio Demo
+Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation
+''' + +# Helper functions +def get_example_img_list(): + print('Loading example img list ...') + return sorted(glob('./demos/example_*.png')) + +def get_example_txt_list(): + print('Loading example txt list ...') + txt_list = [] + for line in open('./demos/example_list.txt'): + txt_list.append(line.strip()) + return txt_list + +example_is = get_example_img_list() +example_ts = get_example_txt_list() + +# Import required workers +from infer import seed_everything, save_gif +from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer + +# Initialize workers +worker_xbg = Removebg() +print(f"loading {args.text2image_path}") +worker_t2i = Text2Image( + pretrain=args.text2image_path, + device=args.device, + save_memory=args.save_memory +) +worker_i2v = Image2Views( + use_lite=args.use_lite, + device=args.device +) +worker_v23 = Views2Mesh( + args.mv23d_cfg_path, + args.mv23d_ckt_path, + use_lite=args.use_lite, + device=args.device +) +worker_gif = GifRenderer(args.device) + +# Pipeline stages +def stage_0_t2i(text, image, seed, step): + os.makedirs('./outputs/app_output', exist_ok=True) + exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith(".")) + cur_id = min(set(range(30)) - exists) if len(exists) < 30 else 0 + + if os.path.exists(f"./outputs/app_output/{(cur_id + 1) % 30}"): + shutil.rmtree(f"./outputs/app_output/{(cur_id + 1) % 30}") + save_folder = f'./outputs/app_output/{cur_id}' + os.makedirs(save_folder, exist_ok=True) + + dst = save_folder + '/img.png' + + if not text: + if image is None: + return dst, save_folder + image.save(dst) + return dst, save_folder + + image = worker_t2i(text, seed, step) + image.save(dst) + dst = worker_xbg(image, save_folder) + return dst, save_folder + +def stage_1_xbg(image, save_folder): + if isinstance(image, str): + image = Image.open(image) + dst = save_folder + '/img_nobg.png' + rgba = worker_xbg(image) + rgba.save(dst) + return dst + +def stage_2_i2v(image, seed, step, save_folder): + if isinstance(image, str): + image = Image.open(image) + gif_dst = save_folder + '/views.gif' + res_img, pils = worker_i2v(image, seed, step) + save_gif(pils, gif_dst) + views_img, cond_img = res_img[0], res_img[1] + img_array = np.asarray(views_img, dtype=np.uint8) + show_img = rearrange(img_array, '(n h) (m w) c -> (n m) h w c', n=3, m=2) + show_img = show_img[worker_i2v.order, ...] 
+ show_img = rearrange(show_img, '(n m) h w c -> (n h) (m w) c', n=2, m=3) + show_img = Image.fromarray(show_img) + return views_img, cond_img, show_img + +def stage_3_v23(views_pil, cond_pil, seed, save_folder, target_face_count=30000, + do_texture_mapping=True, do_render=True): + do_texture_mapping = do_texture_mapping or do_render + obj_dst = save_folder + '/mesh_with_colors.obj' + glb_dst = save_folder + '/mesh.glb' + worker_v23( + views_pil, + cond_pil, + seed=seed, + save_folder=save_folder, + target_face_count=target_face_count, + do_texture_mapping=do_texture_mapping + ) + return obj_dst, glb_dst + +def stage_4_gif(obj_dst, save_folder, do_render_gif=True): + if not do_render_gif: + return None + gif_dst = save_folder + '/output.gif' + worker_gif( + save_folder + '/mesh.obj', + gif_dst_path=gif_dst + ) + return gif_dst + +# Gradio Interface +with gr.Blocks() as demo: + gr.Markdown(CONST_HEADER) + + with gr.Row(variant="panel"): + with gr.Column(scale=2): + with gr.Tab("Text to 3D"): + with gr.Column(): + text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。', + lines=1, max_lines=10, label='Input text') + with gr.Row(): + textgen_seed = gr.Number(value=0, label="T2I seed", precision=0) + textgen_step = gr.Number(value=25, label="T2I step", precision=0) + textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0) + textgen_STEP = gr.Number(value=50, label="Gen step", precision=0) + textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0) + + with gr.Row(): + textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False) + textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False) + textgen_submit = gr.Button("Generate", variant="primary") + + gr.Examples(examples=example_ts, inputs=[text], label="Txt examples") + + with gr.Tab("Image to 3D"): + with gr.Column(): + input_image = gr.Image(label="Input image", width=256, height=256, + type="pil", image_mode="RGBA", sources="upload") + with gr.Row(): + imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0) + imggen_STEP = gr.Number(value=50, label="Gen step", precision=0) + imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0) + + with gr.Row(): + imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False) + imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False) + imggen_submit = gr.Button("Generate", variant="primary") + + gr.Examples(examples=example_is, inputs=[input_image], label="Img examples") + + with gr.Column(scale=3): + with gr.Tab("rembg image"): + rem_bg_image = gr.Image(label="No background image", width=256, height=256, + type="pil", image_mode="RGBA") + + with gr.Tab("Multi views"): + result_image = gr.Image(label="Multi views", type="pil") + with gr.Tab("Obj"): + result_3dobj = gr.Model3D(label="Output obj") + with gr.Tab("Glb"): + result_3dglb = gr.Model3D(label="Output glb") + with gr.Tab("GIF"): + result_gif = gr.Image(label="Rendered GIF") + + # States + none = gr.State(None) + save_folder = gr.State() + cond_image = gr.State() + views_image = gr.State() + text_image = gr.State() + + # Event handlers + textgen_submit.click( + fn=stage_0_t2i, + inputs=[text, none, textgen_seed, textgen_step], + outputs=[rem_bg_image, save_folder], + ).success( + fn=stage_2_i2v, + inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder], + outputs=[views_image, cond_image, result_image], + ).success( + fn=stage_3_v23, + inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, + 
textgen_do_texture_mapping, textgen_do_render_gif], + outputs=[result_3dobj, result_3dglb], + ).success( + fn=stage_4_gif, + inputs=[result_3dglb, save_folder, textgen_do_render_gif], + outputs=[result_gif], + ) + + imggen_submit.click( + fn=stage_0_t2i, + inputs=[none, input_image, textgen_seed, textgen_step], + outputs=[text_image, save_folder], + ).success( + fn=stage_1_xbg, + inputs=[text_image, save_folder], + outputs=[rem_bg_image], + ).success( + fn=stage_2_i2v, + inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder], + outputs=[views_image, cond_image, result_image], + ).success( + fn=stage_3_v23, + inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, + imggen_do_texture_mapping, imggen_do_render_gif], + outputs=[result_3dobj, result_3dglb], + ).success( + fn=stage_4_gif, + inputs=[result_3dglb, save_folder, imggen_do_render_gif], + outputs=[result_gif], + ) + + demo.queue(max_size=CONST_MAX_QUEUE) + demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT) \ No newline at end of file diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..06dab332d2138fe5ddbd0555780817e9602e779d Binary files /dev/null and b/assets/logo.png differ diff --git a/assets/overview_3.png b/assets/overview_3.png new file mode 100644 index 0000000000000000000000000000000000000000..edbec8202aeed23a277c94dc92d15515db2b18dc Binary files /dev/null and b/assets/overview_3.png differ diff --git a/assets/radar.png b/assets/radar.png new file mode 100644 index 0000000000000000000000000000000000000000..fd57b17e36f15e9e78d4e8177fd62cef00f0e94b Binary files /dev/null and b/assets/radar.png differ diff --git a/assets/runtime.png b/assets/runtime.png new file mode 100644 index 0000000000000000000000000000000000000000..220d288e9f78cfb4e6bf4ed811cb2890e52925ae Binary files /dev/null and b/assets/runtime.png differ diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..f459167e3a32df9629dce6df1ffc41a6994825af --- /dev/null +++ b/assets/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af24eeebe39864d377b7ef8e11521a8b7cba964c14032cc28bd0d95bd5219c00 +size 3097514 diff --git a/demos/example_000.png b/demos/example_000.png new file mode 100644 index 0000000000000000000000000000000000000000..6891f23a3a621255ee97a100bb40b5a27a9cfa20 Binary files /dev/null and b/demos/example_000.png differ diff --git a/demos/example_001.png b/demos/example_001.png new file mode 100644 index 0000000000000000000000000000000000000000..667dc22ce710ab308de3589b2f42cf7d10131ec8 Binary files /dev/null and b/demos/example_001.png differ diff --git a/demos/example_002.png b/demos/example_002.png new file mode 100644 index 0000000000000000000000000000000000000000..1c359f2bcffa3603462249a5046fe4f7a63c0435 Binary files /dev/null and b/demos/example_002.png differ diff --git a/demos/example_003.png b/demos/example_003.png new file mode 100644 index 0000000000000000000000000000000000000000..6c1e30acfb67a99da531571d2feaf5d3060ffefc --- /dev/null +++ b/demos/example_003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d947e0ef10baf761abb78d2842519ae7428bc6eadab26a159510ddcaf2a47e67 +size 1066046 diff --git a/demos/example_list.txt b/demos/example_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..09d331354dc6978d24793bb4557b672485a9f5c9 --- /dev/null +++ b/demos/example_list.txt @@ -0,0 +1,2 @@ +a pot of green plants 
grows in a red flower pot. +a lovely rabbit eating carrots diff --git a/infer/__init__.py b/infer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5db8c89d17f7bdf774877d9d52f0c42f16d87617 --- /dev/null +++ b/infer/__init__.py @@ -0,0 +1,28 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +from .utils import seed_everything, timing_decorator, auto_amp_inference +from .rembg import Removebg +from .text_to_image import Text2Image +from .image_to_views import Image2Views, save_gif +from .views_to_mesh import Views2Mesh +from .gif_render import GifRenderer diff --git a/infer/gif_render.py b/infer/gif_render.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5bceffba12b5a595e7be889cc9c979f3db8c98 --- /dev/null +++ b/infer/gif_render.py @@ -0,0 +1,55 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. 
+ +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +from svrm.ldm.vis_util import render +from .utils import seed_everything, timing_decorator + +class GifRenderer(): + ''' + render frame(s) of mesh using pytorch3d + ''' + def __init__(self, device="cuda:0"): + self.device = device + + @timing_decorator("gif render") + def __call__( + self, + obj_filename, + elev=0, + azim=0, + resolution=512, + gif_dst_path='', + n_views=120, + fps=30, + rgb=True + ): + render( + obj_filename, + elev=elev, + azim=azim, + resolution=resolution, + gif_dst_path=gif_dst_path, + n_views=n_views, + fps=fps, + device=self.device, + rgb=rgb + ) diff --git a/infer/image_to_views.py b/infer/image_to_views.py new file mode 100644 index 0000000000000000000000000000000000000000..97a9ce3fe22a9bbd0bc0e05ea22c9007fc3e296d --- /dev/null +++ b/infer/image_to_views.py @@ -0,0 +1,81 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
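For reference, a minimal sketch of how the five workers exported by `infer/__init__.py` chain together outside Gradio, mirroring `stage_0_t2i` through `stage_4_gif` in `app.py` above. This is an illustrative script, not part of the Space itself; the weight paths, seeds/steps, and output filenames (`weights/hunyuanDiT`, `weights/svrm/svrm.safetensors`, `mesh.obj`) are taken from the argparse defaults and stage functions in `app.py` and may need adjusting for a local checkout.

```python
import os
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer, save_gif

save_folder = "./outputs/manual_run"
os.makedirs(save_folder, exist_ok=True)

t2i  = Text2Image(pretrain="weights/hunyuanDiT", device="cuda:0")
xbg  = Removebg()
i2v  = Image2Views(use_lite=False, device="cuda:0")
v23  = Views2Mesh("./svrm/configs/svrm.yaml", "weights/svrm/svrm.safetensors",
                  use_lite=False, device="cuda:0")
gifr = GifRenderer(device="cuda:0")

image = t2i("a lovely rabbit eating carrots", seed=0, steps=25)  # text  -> RGB image
rgba  = xbg(image)                                               # image -> RGBA, background removed
res_img, pils = i2v(rgba, seed=0, steps=50)                      # image -> multi-view sheet + per-view PILs
views_img, cond_img = res_img[0], res_img[1]
save_gif(pils, f"{save_folder}/views.gif")
v23(views_img, cond_img, seed=0, save_folder=save_folder,        # views -> mesh files under save_folder
    target_face_count=90000, do_texture_mapping=True)
gifr(f"{save_folder}/mesh.obj", gif_dst_path=f"{save_folder}/output.gif")  # mesh -> turntable gif
```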
+ +import os +import time +import torch +import random +import numpy as np +from PIL import Image +from einops import rearrange +from PIL import Image, ImageSequence + +from .utils import seed_everything, timing_decorator, auto_amp_inference +from .utils import get_parameter_number, set_parameter_grad_false +from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline +from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline + + +def save_gif(pils, save_path, df=False): + # save a list of PIL.Image to gif + spf = 4000 / len(pils) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + pils[0].save(save_path, format="GIF", save_all=True, append_images=pils[1:], duration=spf, loop=0) + return save_path + + +class Image2Views(): + def __init__(self, device="cuda:0", use_lite=False): + self.device = device + if use_lite: + self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained( + "./weights/mvd_lite", + torch_dtype = torch.float16, + use_safetensors = True, + ) + else: + self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained( + "./weights/mvd_std", + torch_dtype = torch.float16, + use_safetensors = True, + ) + self.pipe = self.pipe.to(device) + self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1] + set_parameter_grad_false(self.pipe.unet) + print('image2views unet model', get_parameter_number(self.pipe.unet)) + + @torch.no_grad() + @timing_decorator("image to views") + @auto_amp_inference + def __call__(self, pil_img, seed=0, steps=50, guidance_scale=2.0, guidance_curve=lambda t:2.0): + seed_everything(seed) + generator = torch.Generator(device=self.device) + res_img = self.pipe(pil_img, + num_inference_steps=steps, + guidance_scale=guidance_scale, + guidance_curve=guidance_curve, + generat=generator).images + show_image = rearrange(np.asarray(res_img[0], dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2) + pils = [res_img[1]]+[Image.fromarray(show_image[idx]) for idx in self.order] + torch.cuda.empty_cache() + return res_img, pils + \ No newline at end of file diff --git a/infer/rembg.py b/infer/rembg.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad1031e36670a393b88218a594939cfaab1262c --- /dev/null +++ b/infer/rembg.py @@ -0,0 +1,26 @@ +from rembg import remove, new_session +from .utils import timing_decorator + +class Removebg(): + def __init__(self, name="u2net"): + ''' + name: rembg + ''' + self.session = new_session(name) + + @timing_decorator("remove background") + def __call__(self, rgb_img, force=False): + ''' + inputs: + rgb_img: PIL.Image, with RGB mode expected + force: bool, input is RGBA mode + return: + rgba_img: PIL.Image with RGBA mode + ''' + if rgb_img.mode == "RGBA": + if force: + rgb_img = rgb_img.convert("RGB") + else: + return rgb_img + rgba_img = remove(rgb_img, session=self.session) + return rgba_img diff --git a/infer/text_to_image.py b/infer/text_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..ce1ed3b767b0a5d7ecf4260146f350fed0a72618 --- /dev/null +++ b/infer/text_to_image.py @@ -0,0 +1,80 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
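The multi-view sheet returned by `Image2Views` is a 3x2 tile grid; the `rearrange` pattern used in `stage_2_i2v` and `Image2Views.__call__` splits it into six per-view crops, which are then re-indexed with `self.order` ([0, 2, 4, 5, 3, 1] for the std pipeline). A small self-contained sketch of what that pattern does, using a dummy array in place of a real render:

```python
import numpy as np
from einops import rearrange

# Dummy 3x2 grid of 256x256 "views"; each tile is filled with its row-major index.
h = w = 256
grid = np.zeros((3 * h, 2 * w, 3), dtype=np.uint8)
for row in range(3):
    for col in range(2):
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = row * 2 + col

# Same pattern as in stage_2_i2v / Image2Views: split the tiled sheet into 6 crops,
# ordered row-major over the grid (tiles 0, 1 on the first row, 2, 3 on the second, ...).
views = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
assert views.shape == (6, h, w, 3)
assert all(int(views[k, 0, 0, 0]) == k for k in range(6))

# The crops are then re-indexed and re-tiled as a 2x3 sheet for display.
order = [0, 2, 4, 5, 3, 1]
display = rearrange(views[order], '(n m) h w c -> (n h) (m w) c', n=2, m=3)
assert display.shape == (2 * h, 3 * w, 3)
```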
+# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import torch +from .utils import seed_everything, timing_decorator, auto_amp_inference +from .utils import get_parameter_number, set_parameter_grad_false +from diffusers import HunyuanDiTPipeline, AutoPipelineForText2Image + +class Text2Image(): + def __init__(self, pretrain="weights/hunyuanDiT", device="cuda:0", save_memory=False): + ''' + save_memory: if GPU memory is low, can set it + ''' + self.save_memory = save_memory + self.device = device + self.pipe = AutoPipelineForText2Image.from_pretrained( + pretrain, + torch_dtype = torch.float16, + enable_pag = True, + pag_applied_layers = ["blocks.(16|17|18|19)"] + ) + set_parameter_grad_false(self.pipe.transformer) + print('text2image transformer model', get_parameter_number(self.pipe.transformer)) + if not save_memory: + self.pipe = self.pipe.to(device) + self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \ + "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \ + "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子" + + @torch.no_grad() + @timing_decorator('text to image') + @auto_amp_inference + def __call__(self, *args, **kwargs): + if self.save_memory: + self.pipe = self.pipe.to(self.device) + torch.cuda.empty_cache() + res = self.call(*args, **kwargs) + self.pipe = self.pipe.to("cpu") + else: + res = self.call(*args, **kwargs) + torch.cuda.empty_cache() + return res + + def call(self, prompt, seed=0, steps=25): + ''' + inputs: + prompr: str + seed: int + steps: int + return: + rgb: PIL.Image + ''' + prompt = prompt + ",白色背景,3D风格,最佳质量" + seed_everything(seed) + generator = torch.Generator(device=self.device) + if seed is not None: generator = generator.manual_seed(int(seed)) + rgb = self.pipe(prompt=prompt, negative_prompt=self.neg_txt, num_inference_steps=steps, + pag_scale=1.3, width=1024, height=1024, generator=generator, return_dict=False)[0][0] + torch.cuda.empty_cache() + return rgb + \ No newline at end of file diff --git a/infer/utils.py b/infer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..659c9ac26e33b352d74083b104487871e6249b6a --- /dev/null +++ b/infer/utils.py @@ -0,0 +1,77 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). 
All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import os +import time +import random +import numpy as np +import torch +from torch.cuda.amp import autocast, GradScaler +from functools import wraps + +def seed_everything(seed): + ''' + seed everthing + ''' + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + os.environ["PL_GLOBAL_SEED"] = str(seed) + +def timing_decorator(category: str): + ''' + timing_decorator: record time + ''' + def decorator(func): + func.call_count = 0 + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + elapsed_time = end_time - start_time + func.call_count += 1 + print(f"[HunYuan3D]-[{category}], cost time: {elapsed_time:.4f}s") # huiwen + return result + return wrapper + return decorator + +def auto_amp_inference(func): + ''' + with torch.cuda.amp.autocast()" + xxx + ''' + @wraps(func) + def wrapper(*args, **kwargs): + with autocast(): + output = func(*args, **kwargs) + return output + return wrapper + +def get_parameter_number(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return {'Total': total_num, 'Trainable': trainable_num} + +def set_parameter_grad_false(model): + for p in model.parameters(): + p.requires_grad = False diff --git a/infer/views_to_mesh.py b/infer/views_to_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fe680f0fe235501e503bfd410e56931ee1eeff --- /dev/null +++ b/infer/views_to_mesh.py @@ -0,0 +1,94 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. 
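The workers above stack `@torch.no_grad()`, `@timing_decorator(...)` and `@auto_amp_inference` on their `__call__` methods. A small sketch of the same stack applied to a toy function (`toy_matmul` is made up for illustration; the decorators are the ones defined in `infer/utils.py`):

```python
import torch
from infer.utils import seed_everything, timing_decorator, auto_amp_inference

@torch.no_grad()                  # no autograd graph during inference
@timing_decorator("toy matmul")   # prints "[HunYuan3D]-[toy matmul], cost time: ...s"
@auto_amp_inference               # runs the body under torch.cuda.amp.autocast()
def toy_matmul(x, y):
    return x @ y

if torch.cuda.is_available():
    seed_everything(0)
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    out = toy_matmul(a, b)
    print(out.dtype)  # torch.float16: autocast ran the matmul in half precision
```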
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import os +import time +import torch +import random +import numpy as np +from PIL import Image +from einops import rearrange +from PIL import Image, ImageSequence + +from .utils import seed_everything, timing_decorator, auto_amp_inference +from .utils import get_parameter_number, set_parameter_grad_false +from svrm.predictor import MV23DPredictor + + +class Views2Mesh(): + def __init__(self, mv23d_cfg_path, mv23d_ckt_path, device="cuda:0", use_lite=False): + ''' + mv23d_cfg_path: config yaml file + mv23d_ckt_path: path to ckpt + use_lite: + ''' + self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device) + self.mv23d_predictor.model.eval() + self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1] + set_parameter_grad_false(self.mv23d_predictor.model) + print('view2mesh model', get_parameter_number(self.mv23d_predictor.model)) + + @torch.no_grad() + @timing_decorator("views to mesh") + @auto_amp_inference + def __call__( + self, + views_pil=None, + cond_pil=None, + gif_pil=None, + seed=0, + target_face_count = 10000, + do_texture_mapping = True, + save_folder='./outputs/test' + ): + ''' + can set views_pil, cond_pil simutaously or set gif_pil only + seed: int + target_face_count: int + save_folder: path to save mesh files + ''' + save_dir = save_folder + os.makedirs(save_dir, exist_ok=True) + + if views_pil is not None and cond_pil is not None: + show_image = rearrange(np.asarray(views_pil, dtype=np.uint8), + '(n h) (m w) c -> (n m) h w c', n=3, m=2) + views = [Image.fromarray(show_image[idx]) for idx in self.order] + image_list = [cond_pil]+ views + image_list = [img.convert('RGB') for img in image_list] + elif gif_pil is not None: + image_list = [img.convert('RGB') for img in ImageSequence.Iterator(gif_pil)] + + image_input = image_list[0] + image_list = image_list[1:] + image_list[:1] + + seed_everything(seed) + self.mv23d_predictor.predict( + image_list, + save_dir = save_dir, + image_input = image_input, + target_face_count = target_face_count, + do_texture_mapping = do_texture_mapping + ) + torch.cuda.empty_cache() + return save_dir + \ No newline at end of file diff --git a/mvd/__init__.py b/mvd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mvd/hunyuan3d_mvd_lite_pipeline.py b/mvd/hunyuan3d_mvd_lite_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..29abd3a1792a8b96b50bbb61845cb0284255b113 --- /dev/null +++ b/mvd/hunyuan3d_mvd_lite_pipeline.py @@ -0,0 +1,493 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by 
THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import math +import numpy +import torch +import inspect +import warnings +from PIL import Image +from einops import rearrange +import torch.nn.functional as F +from diffusers.utils.torch_utils import randn_tensor +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers import DDPMScheduler, EulerAncestralDiscreteScheduler, ImagePipelineOutput +from diffusers.loaders import ( + FromSingleFileMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin +) +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection +) +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + XFormersAttnProcessor, + AttnProcessor2_0 +) + +from .utils import to_rgb_image, white_out_background, recenter_img + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from here import Hunyuan3d_MVD_Qing_Pipeline + + >>> pipe = Hunyuan3d_MVD_Qing_Pipeline.from_pretrained( + ... "Tencent-Hunyuan-3D/MVD-Qing", torch_dtype=torch.float16 + ... 
) + >>> pipe.to("cuda") + + >>> img = Image.open("demo.png") + >>> res_img = pipe(img).images[0] +""" + +def unscale_latents(latents): return latents / 0.75 + 0.22 +def unscale_image (image ): return image / 0.50 * 0.80 + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + + +class ReferenceOnlyAttnProc(torch.nn.Module): + # reference attention + def __init__(self, chained_proc, enabled=False, name=None): + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None): + if encoder_hidden_states is None: encoder_hidden_states = hidden_states + if self.enabled: + if mode == 'w': + ref_dict[self.name] = encoder_hidden_states + elif mode == 'r': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1) + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) + return res + + +# class RowWiseAttnProcessor2_0: +# def __call__(self, attn, +# hidden_states, +# encoder_hidden_states=None, +# attention_mask=None, +# temb=None, +# num_views=6, +# *args, +# **kwargs): +# residual = hidden_states +# if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) + +# input_ndim = hidden_states.ndim +# if input_ndim == 4: +# batch_size, channel, height, width = hidden_states.shape +# hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + +# if encoder_hidden_states is None: +# batch_size, sequence_length, _ = hidden_states.shape +# else: +# batch_size, sequence_length, _ = encoder_hidden_states.shape + +# if attention_mask is not None: +# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) +# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) +# if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + +# query = attn.to_q(hidden_states) +# if encoder_hidden_states is None: encoder_hidden_states = hidden_states +# elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + +# # encoder_hidden_states [B, 6hw+hw, C] if ref att +# key = attn.to_k(encoder_hidden_states) # [B, Vhw+hw, C] +# value = attn.to_v(encoder_hidden_states) # [B, Vhw+hw, C] + +# mv_flag = hidden_states.shape[1] < encoder_hidden_states.shape[1] and encoder_hidden_states.shape[1] != 77 +# if mv_flag: +# target_size = int(math.sqrt(hidden_states.shape[1] // num_views)) +# assert target_size ** 2 * num_views == hidden_states.shape[1] + +# gen_key = key[:, :num_views*target_size*target_size, :] +# ref_key = key[:, num_views*target_size*target_size:, :] +# gen_value = value[:, :num_views*target_size*target_size, :] +# ref_value = value[:, num_views*target_size*target_size:, :] + +# # rowwise attention +# query, gen_key, gen_value = \ +# rearrange( query, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c", +# v1=num_views//2, v2=2, h=target_size, w=target_size), \ +# rearrange( gen_key, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c", +# v1=num_views//2, v2=2, 
h=target_size, w=target_size), \ +# rearrange(gen_value, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c", +# v1=num_views//2, v2=2, h=target_size, w=target_size) + +# inner_dim = key.shape[-1] +# ref_size = int(math.sqrt(ref_key.shape[1])) +# ref_key_expanded = ref_key.view(batch_size, 1, ref_size * ref_size, inner_dim) +# ref_key_expanded = ref_key_expanded.expand(-1, target_size, -1, -1).contiguous() +# ref_key_expanded = ref_key_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim) +# key = torch.cat([ gen_key, ref_key_expanded], dim=1) + +# ref_value_expanded = ref_value.view(batch_size, 1, ref_size * ref_size, inner_dim) +# ref_value_expanded = ref_value_expanded.expand(-1, target_size, -1, -1).contiguous() +# ref_value_expanded = ref_value_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim) +# value = torch.cat([gen_value, ref_value_expanded], dim=1) +# h = target_size +# else: +# target_size = int(math.sqrt(hidden_states.shape[1])) +# h = 1 +# num_views = 1 + +# inner_dim = key.shape[-1] +# head_dim = inner_dim // attn.heads + +# query = query.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2) +# key = key.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2) +# value = value.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2) + +# hidden_states = F.scaled_dot_product_attention(query, key, value, +# attn_mask=attention_mask, +# dropout_p=0.0, +# is_causal=False) +# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * h, +# -1, +# attn.heads * head_dim).to(query.dtype) +# hidden_states = attn.to_out[1](attn.to_out[0](hidden_states)) + +# if mv_flag: hidden_states = rearrange(hidden_states, "(b h) (v1 v2 w) c -> b (v1 h v2 w) c", +# b=batch_size, v1=num_views//2, +# v2=2, h=target_size, w=target_size) + +# if input_ndim == 4: +# hidden_states = hidden_states.transpose(-1, -2) +# hidden_states = hidden_states.reshape(batch_size, +# channel, +# target_size, +# target_size) +# if attn.residual_connection: hidden_states = hidden_states + residual +# hidden_states = hidden_states / attn.rescale_output_factor +# return hidden_states + + +class RefOnlyNoisedUNet(torch.nn.Module): + def __init__(self, unet, train_sched, val_sched): + super().__init__() + self.unet = unet + self.train_sched = train_sched + self.val_sched = val_sched + + unet_lora_attn_procs = dict() + for name, _ in unet.attn_processors.items(): + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(AttnProcessor2_0(), + enabled=name.endswith("attn1.processor"), + name=name) + unet.set_attn_processor(unet_lora_attn_procs) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward(self, sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs, **kwargs): + cond_lat = cross_attention_kwargs['cond_lat'] + noise = torch.randn_like(cond_lat) + if self.training: + noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep) + noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep) + else: + noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1)) + noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) + + ref_dict = {} + self.unet(noisy_cond_lat, + timestep, + encoder_hidden_states, + *args, + cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict), + **kwargs) + return self.unet(sample, + timestep, + encoder_hidden_states, + *args, + 
cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict), + **kwargs) + + +class Hunyuan3d_MVD_Lite_Pipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vision_encoder: CLIPVisionModelWithProjection, + feature_extractor_clip: CLIPImageProcessor, + feature_extractor_vae: CLIPImageProcessor, + ramping_coefficients: Optional[list] = None, + safety_checker=None, + ): + DiffusionPipeline.__init__(self) + self.register_modules( + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + text_encoder=text_encoder, + vision_encoder=vision_encoder, + feature_extractor_vae=feature_extractor_vae, + feature_extractor_clip=feature_extractor_clip) + ''' + rewrite the stable diffusion pipeline + vae: vae + unet: unet + tokenizer: tokenizer + scheduler: scheduler + text_encoder: text_encoder + vision_encoder: vision_encoder + feature_extractor_vae: feature_extractor_vae + feature_extractor_clip: feature_extractor_clip + ''' + self.register_to_config(ramping_coefficients=ramping_coefficients) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def prepare_extra_step_kwargs(self, generator, eta): + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: extra_step_kwargs["eta"] = eta + + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)[0] + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + 
prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError() + elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): raise ValueError() + else: uncond_tokens = negative_prompt + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer(uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt") + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=attention_mask) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + @torch.no_grad() + def encode_condition_image(self, image: torch.Tensor): return self.vae.encode(image).latent_dist.sample() + + @torch.no_grad() + def __call__(self, image=None, + width=640, + height=960, + num_inference_steps=75, + return_dict=True, + generator=None, + **kwargs): + batch_size = 1 + num_images_per_prompt = 1 + output_type = 'pil' + do_classifier_free_guidance = True + guidance_rescale = 0. 
+ if isinstance(self.unet, UNet2DConditionModel): + self.unet = RefOnlyNoisedUNet(self.unet, None, self.scheduler).eval() + + cond_image = recenter_img(image) + cond_image = to_rgb_image(image) + image = cond_image + image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values + image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values + image_1 = image_1.to(device=self.vae.device, dtype=self.vae.dtype) + image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype) + + cond_lat = self.encode_condition_image(image_1) + negative_lat = self.encode_condition_image(torch.zeros_like(image_1)) + cond_lat = torch.cat([negative_lat, cond_lat]) + cross_attention_kwargs = dict(cond_lat=cond_lat) + + global_embeds = self.vision_encoder(image_2, output_hidden_states=False).image_embeds.unsqueeze(-2) + encoder_hidden_states = self._encode_prompt('', self.device, num_images_per_prompt, False) + ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1) + prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states + global_embeds * ramp]) + + device = self._execution_device + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents(batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + None) + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # set adaptive cfg + # the image order is: + # [0, 60, + # 120, 180, + # 240, 300] + # the cfg is set as 3, 2.5, 2, 1.5 + + tmp_guidance_scale = torch.ones_like(latents) + tmp_guidance_scale[:, :, :40, :40] = 3 + tmp_guidance_scale[:, :, :40, 40:] = 2.5 + tmp_guidance_scale[:, :, 40:80, :40] = 2 + tmp_guidance_scale[:, :, 40:80, 40:] = 1.5 + tmp_guidance_scale[:, :, 80:120, :40] = 2 + tmp_guidance_scale[:, :, 80:120, 40:] = 2.5 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.unet(latent_model_input, t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False)[0] + + adaptive_guidance_scale = (2 + 16 * (t / 1000) ** 5) / 3 + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + \ + tmp_guidance_scale * adaptive_guidance_scale * \ + (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if i==len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order==0): + progress_bar.update() + + latents = unscale_latents(latents) + image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]) + image = self.image_processor.postprocess(image, output_type='pil')[0] + image = [image, cond_image] + return ImagePipelineOutput(images=image) if return_dict else (image,) + diff --git a/mvd/hunyuan3d_mvd_std_pipeline.py 
b/mvd/hunyuan3d_mvd_std_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..88cd8bae1a35d663e3275e50077af59ee3fcae33 --- /dev/null +++ b/mvd/hunyuan3d_mvd_std_pipeline.py @@ -0,0 +1,471 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import inspect +from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Tuple, Union + +import os +import torch +import numpy as np +from PIL import Image + +import diffusers +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.import_utils import is_xformers_available +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + XFormersAttnProcessor, + AttnProcessor2_0 +) +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + EulerAncestralDiscreteScheduler, + UNet2DConditionModel, + ImagePipelineOutput +) +import transformers +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + CLIPTextModelWithProjection +) + +from .utils import to_rgb_image, white_out_background, recenter_img + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Hunyuan3d_MVD_XL_Pipeline + + >>> pipe = Hunyuan3d_MVD_XL_Pipeline.from_pretrained( + ... "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16 + ... 
) + >>> pipe.to("cuda") + + >>> img = Image.open("demo.png") + >>> res_img = pipe(img).images[0] + ``` +""" + + + +def scale_latents(latents): return (latents - 0.22) * 0.75 +def unscale_latents(latents): return (latents / 0.75) + 0.22 +def scale_image(image): return (image - 0.5) / 0.5 +def scale_image_2(image): return (image * 0.5) / 0.8 +def unscale_image(image): return (image * 0.5) + 0.5 +def unscale_image_2(image): return (image * 0.8) / 0.5 + + + + +class ReferenceOnlyAttnProc(torch.nn.Module): + def __init__(self, chained_proc, enabled=False, name=None): + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None): + encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + if self.enabled: + if mode == 'w': ref_dict[self.name] = encoder_hidden_states + elif mode == 'r': encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1) + else: raise Exception(f"mode should not be {mode}") + return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) + + +class RefOnlyNoisedUNet(torch.nn.Module): + def __init__(self, unet, scheduler) -> None: + super().__init__() + self.unet = unet + self.scheduler = scheduler + + unet_attn_procs = dict() + for name, _ in unet.attn_processors.items(): + if torch.__version__ >= '2.0': default_attn_proc = AttnProcessor2_0() + elif is_xformers_available(): default_attn_proc = XFormersAttnProcessor() + else: default_attn_proc = AttnProcessor() + unet_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, enabled=name.endswith("attn1.processor"), name=name + ) + unet.set_attn_processor(unet_attn_procs) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + class_labels: Optional[torch.Tensor] = None, + down_block_res_samples: Optional[Tuple[torch.Tensor]] = None, + mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + return_dict: bool = True, + **kwargs + ): + + dtype = self.unet.dtype + + # cond_lat add same level noise + cond_lat = cross_attention_kwargs['cond_lat'] + noise = torch.randn_like(cond_lat) + + noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1)) + noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) + + ref_dict = {} + + _ = self.unet( + noisy_cond_lat, + timestep, + encoder_hidden_states = encoder_hidden_states, + class_labels = class_labels, + cross_attention_kwargs = dict(mode="w", ref_dict=ref_dict), + added_cond_kwargs = added_cond_kwargs, + return_dict = return_dict, + **kwargs + ) + + res = self.unet( + sample, + timestep, + encoder_hidden_states, + class_labels=class_labels, + cross_attention_kwargs = dict(mode="r", ref_dict=ref_dict), + down_block_additional_residuals = [ + sample.to(dtype=dtype) for sample in down_block_res_samples + ] if down_block_res_samples is not None else None, + mid_block_additional_residual = ( + mid_block_res_sample.to(dtype=dtype) + if mid_block_res_sample is not None else None), + added_cond_kwargs = added_cond_kwargs, + return_dict = 
return_dict, + **kwargs + ) + return res + + + +class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + feature_extractor_vae: CLIPImageProcessor, + vision_processor: CLIPImageProcessor, + vision_encoder: CLIPVisionModelWithProjection, + vision_encoder_2: CLIPVisionModelWithProjection, + ramping_coefficients: Optional[list] = None, + add_watermarker: Optional[bool] = None, + safety_checker = None, + ): + DiffusionPipeline.__init__(self) + + self.register_modules( + vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae, + vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2, + ) + self.register_to_config( ramping_coefficients = ramping_coefficients) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + self.watermark = None + self.prepare_init = False + + def prepare(self): + assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel" + self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval() + self.prepare_init = True + + def encode_image(self, image: torch.Tensor, scale_factor: bool = False): + latent = self.vae.encode(image).latent_dist.sample() + return (latent * self.vae.config.scaling_factor) if scale_factor else latent + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \ + f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config." \ + f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def interrupt(self): + return self._interrupt + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @torch.no_grad() + def __call__( + self, + image: Image.Image = None, + guidance_scale = 2.0, + output_type: Optional[str] = "pil", + num_inference_steps: int = 50, + return_dict: bool = True, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + latent: torch.Tensor = None, + guidance_curve = None, + **kwargs + ): + if not self.prepare_init: + self.prepare() + + here = dict(device=self.vae.device, dtype=self.vae.dtype) + + batch_size = 1 + num_images_per_prompt = 1 + width, height = 512 * 2, 512 * 3 + target_size = original_size = (height, width) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + device = self._execution_device + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + self.vae.dtype, + device, + generator, + latents=latent, + ) + + # Prepare extra step kwargs. 
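+ # [editor note] For reference, the latents prepared above form a multiview grid. Assuming
+ # the usual 4-channel VAE with 8x spatial downsampling (vae_scale_factor = 8, not confirmed
+ # here): height, width = 1536, 1024 -> latents of shape (1, 4, 1536 // 8, 1024 // 8) = (1, 4, 192, 128),
+ # i.e. a 3x2 grid of 64x64 latent tiles, one per 512x512 view. This is why the commented-out
+ # per-view guidance further below slices noise_pred[:, :, :h//3, :w//2] for the top-left view.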
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + + # Prepare added time ids & embeddings + text_encoder_projection_dim = 1280 + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=self.vae.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + negative_add_time_ids = add_time_ids + + # hw: preprocess + cond_image = recenter_img(image) + cond_image = to_rgb_image(image) + image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here) + image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here) + + # hw: get cond_lat from cond_img using vae + cond_lat = self.encode_image(image_vae, scale_factor=False) + negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False) + cond_lat = torch.cat([negative_lat, cond_lat]) + + # hw: get visual global embedding using clip + global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2) + global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2) + global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1) + + ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1) + prompt_embeds = self.uc_text_emb.to(**here) + pooled_prompt_embeds = self.uc_text_emb_2.to(**here) + + prompt_embeds = prompt_embeds + global_embeds * ramp + add_text_embeds = pooled_prompt_embeds + + if self.do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + timestep_cond = None + self._num_timesteps = len(timesteps) + + if guidance_curve is None: + guidance_curve = lambda t: guidance_scale + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=dict(cond_lat=cond_lat), + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + + # cur_guidance_scale = self.guidance_scale + cur_guidance_scale = guidance_curve(t) # 1.5 + 2.5 * ((t/1000)**2) + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond) + + # cur_guidance_scale_topleft = 
(cur_guidance_scale - 1.0) * 4 + 1.0 + # noise_pred_top_left = noise_pred_uncond + + # cur_guidance_scale_topleft * (noise_pred_text - noise_pred_uncond) + # _, _, h, w = noise_pred.shape + # noise_pred[:, :, :h//3, :w//2] = noise_pred_top_left[:, :, :h//3, :w//2] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = unscale_latents(latents) + + if output_type=="latent": + image = latents + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = unscale_image(unscale_image_2(image)).clamp(0, 1) + image = [ + Image.fromarray((image[0]*255+0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")), + # self.image_processor.postprocess(image, output_type=output_type)[0], + cond_image.resize((512, 512)) + ] + + if not return_dict: return (image,) + return ImagePipelineOutput(images=image) + + def save_pretrained(self, save_directory): + # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance + super().save_pretrained(save_directory) + torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt")) + torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt")) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance + pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt")) + pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt")) + return pipeline diff --git a/mvd/utils.py b/mvd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7423e9d75446363d15b0eb9e547e17f31c65f88 --- /dev/null +++ b/mvd/utils.py @@ -0,0 +1,85 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. 
+ +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import numpy as np +from PIL import Image + +def to_rgb_image(maybe_rgba: Image.Image): + ''' + convert a PIL.Image to rgb mode with white background + maybe_rgba: PIL.Image + return: PIL.Image + ''' + if maybe_rgba.mode == 'RGB': + return maybe_rgba + elif maybe_rgba.mode == 'RGBA': + rgba = maybe_rgba + img = np.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8) + img = Image.fromarray(img, 'RGB') + img.paste(rgba, mask=rgba.getchannel('A')) + return img + else: + raise ValueError("Unsupported image type.", maybe_rgba.mode) + +def white_out_background(pil_img, is_gray_fg=True): + data = pil_img.getdata() + new_data = [] + # convert fore-ground white to gray + for r, g, b, a in data: + if a < 16: + new_data.append((255, 255, 255, 0)) # back-ground to be black + else: + is_white = is_gray_fg and (r>235) and (g>235) and (b>235) + new_r = 235 if is_white else r + new_g = 235 if is_white else g + new_b = 235 if is_white else b + new_data.append((new_r, new_g, new_b, a)) + pil_img.putdata(new_data) + return pil_img + +def recenter_img(img, size=512, color=(255,255,255)): + img = white_out_background(img) + mask = np.array(img)[..., 3] + image = np.array(img)[..., :3] + + H, W, C = image.shape + coords = np.nonzero(mask) + x_min, x_max = coords[0].min(), coords[0].max() + y_min, y_max = coords[1].min(), coords[1].max() + h = x_max - x_min + w = y_max - y_min + if h == 0 or w == 0: raise ValueError + roi = image[x_min:x_max, y_min:y_max] + + border_ratio = 0.15 # 0.2 + pad_h = int(h * border_ratio) + pad_w = int(w * border_ratio) + + result_tmp = np.full((h + pad_h, w + pad_w, C), color, dtype=np.uint8) + result_tmp[pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w] = roi + + cur_h, cur_w = result_tmp.shape[:2] + side = max(cur_h, cur_w) + result = np.full((side, side, C), color, dtype=np.uint8) + result[(side-cur_h)//2:(side-cur_h)//2+cur_h, (side-cur_w)//2:(side - cur_w)//2+cur_w,:] = result_tmp + result = Image.fromarray(result) + return result.resize((size, size), Image.LANCZOS) if size else result diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a187439832c19e4a7902be324919d035c28a097c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +--find-links https://download.pytorch.org/whl/cu118 +torch==2.2.0 +torchvision==0.17.0 +diffusers +transformers +rembg +tqdm +omegaconf +matplotlib +opencv-python +imageio +jaxtyping +einops +SentencePiece +accelerate +trimesh +PyMCubes +xatlas +libigl +git+https://github.com/facebookresearch/pytorch3d +git+https://github.com/NVlabs/nvdiffrast +open3d \ No newline at end of file diff --git a/scripts/image_to_3d.sh b/scripts/image_to_3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..730bef20d97efe4dc715c4164d46007b5072bfe8 --- /dev/null +++ b/scripts/image_to_3d.sh @@ -0,0 +1,8 @@ +# image to 3d + +python main.py \ + --image_prompt ./demos/example_000.png \ + --save_folder ./outputs/test/ \ + --max_faces_num 90000 \ + --do_texture \ + --do_render diff --git a/scripts/image_to_3d_demo.sh 
b/scripts/image_to_3d_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..eea6263c4e3e7a4ca1685bd26c0299c05f8e2e83 --- /dev/null +++ b/scripts/image_to_3d_demo.sh @@ -0,0 +1,8 @@ +# image to 3d + +python main.py \ + --image_prompt ./demos/example_000.png \ + --save_folder ./outputs/test/ \ + --max_faces_num 90000 \ + --do_texture_mapping \ + --do_render diff --git a/scripts/image_to_3d_fast.sh b/scripts/image_to_3d_fast.sh new file mode 100644 index 0000000000000000000000000000000000000000..024107a438e01dbad99924bf59cad92fb924e307 --- /dev/null +++ b/scripts/image_to_3d_fast.sh @@ -0,0 +1,6 @@ +# image to 3d fast +python main.py \ + --image_prompt ./demos/example_000.png \ + --save_folder ./outputs/test/ \ + --max_faces_num 10000 \ + --use_lite diff --git a/scripts/image_to_3d_fast_demo.sh b/scripts/image_to_3d_fast_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..024107a438e01dbad99924bf59cad92fb924e307 --- /dev/null +++ b/scripts/image_to_3d_fast_demo.sh @@ -0,0 +1,6 @@ +# image to 3d fast +python main.py \ + --image_prompt ./demos/example_000.png \ + --save_folder ./outputs/test/ \ + --max_faces_num 10000 \ + --use_lite diff --git a/scripts/text_to_3d.sh b/scripts/text_to_3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..a81341df18eceef0f47ad8e41229fbc6ff0de3b6 --- /dev/null +++ b/scripts/text_to_3d.sh @@ -0,0 +1,7 @@ +# text to 3d fast +python main.py \ + --text_prompt "a lovely cat" \ + --save_folder ./outputs/test/ \ + --max_faces_num 90000 \ + --do_texture \ + --do_render diff --git a/scripts/text_to_3d_demo.sh b/scripts/text_to_3d_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..88f46bfc1d3da38b08014a4dda85103dc6e19b9d --- /dev/null +++ b/scripts/text_to_3d_demo.sh @@ -0,0 +1,7 @@ +# text to 3d fast +python main.py \ + --text_prompt "a lovely rabbit" \ + --save_folder ./outputs/test/ \ + --max_faces_num 90000 \ + --do_texture_mapping \ + --do_render diff --git a/scripts/text_to_3d_fast.sh b/scripts/text_to_3d_fast.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed41c4a887d050dbd6b2da839c87d174b5c6efc1 --- /dev/null +++ b/scripts/text_to_3d_fast.sh @@ -0,0 +1,6 @@ +# text to 3d fast +python main.py \ + --text_prompt "一个广式茶杯" \ + --save_folder ./outputs/test/ \ + --max_faces_num 10000 \ + --use_lite diff --git a/scripts/text_to_3d_fast_demo.sh b/scripts/text_to_3d_fast_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed41c4a887d050dbd6b2da839c87d174b5c6efc1 --- /dev/null +++ b/scripts/text_to_3d_fast_demo.sh @@ -0,0 +1,6 @@ +# text to 3d fast +python main.py \ + --text_prompt "一个广式茶杯" \ + --save_folder ./outputs/test/ \ + --max_faces_num 10000 \ + --use_lite diff --git a/svrm/.DS_Store b/svrm/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..d9f5f94a8d4408fa092739ff9d5775a56e8da914 Binary files /dev/null and b/svrm/.DS_Store differ diff --git a/svrm/configs/2024-10-24T22-36-18-project.yaml b/svrm/configs/2024-10-24T22-36-18-project.yaml new file mode 100644 index 0000000000000000000000000000000000000000..286ff8e108a3b241aed860d21681479ccf0a1e63 --- /dev/null +++ b/svrm/configs/2024-10-24T22-36-18-project.yaml @@ -0,0 +1,32 @@ +model: + base_learning_rate: 3.0e-05 + target: svrm.ldm.models.svrm.SVRMModel + params: + + img_encoder_config: + target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder + params: + version: dinov2_vitb14 + + img_to_triplane_config: + target: 
svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel + params: + pos_emb_size: 64 + pos_emb_dim: 1024 + cam_cond_dim: 20 + n_heads: 16 + d_head: 64 + depth: 16 + context_dim: 768 + triplane_dim: 120 + use_fp16: true + use_bf16: false + upsample_time: 2 + + render_config: + target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer + params: + triplane_dim: 120 + samples_per_ray: 128 + + diff --git a/svrm/configs/svrm.yaml b/svrm/configs/svrm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..286ff8e108a3b241aed860d21681479ccf0a1e63 --- /dev/null +++ b/svrm/configs/svrm.yaml @@ -0,0 +1,32 @@ +model: + base_learning_rate: 3.0e-05 + target: svrm.ldm.models.svrm.SVRMModel + params: + + img_encoder_config: + target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder + params: + version: dinov2_vitb14 + + img_to_triplane_config: + target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel + params: + pos_emb_size: 64 + pos_emb_dim: 1024 + cam_cond_dim: 20 + n_heads: 16 + d_head: 64 + depth: 16 + context_dim: 768 + triplane_dim: 120 + use_fp16: true + use_bf16: false + upsample_time: 2 + + render_config: + target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer + params: + triplane_dim: 120 + samples_per_ray: 128 + + diff --git a/svrm/ldm/.DS_Store b/svrm/ldm/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb57e5a2bfa5f6208a6a12d082e457410cb415be Binary files /dev/null and b/svrm/ldm/.DS_Store differ diff --git a/svrm/ldm/models/svrm.py b/svrm/ldm/models/svrm.py new file mode 100644 index 0000000000000000000000000000000000000000..e12511ec16095176b1d3bed70eaa18a29b7e1b0f --- /dev/null +++ b/svrm/ldm/models/svrm.py @@ -0,0 +1,263 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
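+ # [editor note] The svrm/configs/*.yaml files above describe this model as nested
+ # target/params pairs. instantiate_from_config (imported from ..util below) is assumed here
+ # to follow the standard latent-diffusion idiom, roughly:
+ #     import importlib
+ #     def instantiate_from_config(cfg):
+ #         module, cls = cfg["target"].rsplit(".", 1)
+ #         return getattr(importlib.import_module(module), cls)(**cfg.get("params", {}))
+ # so SVRMModel builds its image encoder, img-to-triplane decoder and renderer directly from
+ # the three sub-configs it receives.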
+ +import os +import time +import math +import cv2 +import numpy as np +import itertools +import shutil +from tqdm import tqdm +import torch +import torch.nn.functional as F +from einops import rearrange +try: + import trimesh + import mcubes + import xatlas + import open3d as o3d +except: + raise "failed to import 3d libraries " + +from ..modules.rendering_neus.mesh import Mesh +from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext + +from ..utils.ops import scale_tensor +from ..util import count_params, instantiate_from_config +from ..vis_util import render + + +def unwrap_uv(v_pos, t_pos_idx): + print("Using xatlas to perform UV unwrapping, may take a while ...") + atlas = xatlas.Atlas() + atlas.add_mesh(v_pos, t_pos_idx) + atlas.generate(xatlas.ChartOptions(), xatlas.PackOptions()) + _, indices, uvs = atlas.get_mesh(0) + indices = indices.astype(np.int64, casting="same_kind") + return uvs, indices + + +def uv_padding(image, hole_mask, uv_padding_size = 2): + return cv2.inpaint( + (image.detach().cpu().numpy() * 255).astype(np.uint8), + (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8), + uv_padding_size, + cv2.INPAINT_TELEA + ) + +def refine_mesh(vtx_refine, faces_refine): + mesh = o3d.geometry.TriangleMesh( + vertices=o3d.utility.Vector3dVector(vtx_refine), + triangles=o3d.utility.Vector3iVector(faces_refine)) + + mesh = mesh.remove_unreferenced_vertices() + mesh = mesh.remove_duplicated_triangles() + mesh = mesh.remove_duplicated_vertices() + + voxel_size = max(mesh.get_max_bound() - mesh.get_min_bound()) + + mesh = mesh.simplify_vertex_clustering( + voxel_size=0.007, # 0.005 + contraction=o3d.geometry.SimplificationContraction.Average) + + mesh = mesh.filter_smooth_simple(number_of_iterations=2) + + vtx_refine = np.asarray(mesh.vertices).astype(np.float32) + faces_refine = np.asarray(mesh.triangles) + return vtx_refine, faces_refine, mesh + + +class SVRMModel(torch.nn.Module): + def __init__( + self, + img_encoder_config, + img_to_triplane_config, + render_config, + device = "cuda:0", + **kwargs + ): + super().__init__() + + self.img_encoder = instantiate_from_config(img_encoder_config).half() + self.img_to_triplane_decoder = instantiate_from_config(img_to_triplane_config).half() + self.render = instantiate_from_config(render_config).half() + self.device = device + count_params(self, verbose=True) + + @torch.no_grad() + def export_mesh_with_uv( + self, + data, + mesh_size: int = 384, + ctx = None, + context_type = 'cuda', + texture_res = 1024, + target_face_count = 10000, + do_texture_mapping = True, + out_dir = 'outputs/test' + ): + """ + color_type: 0 for ray texture, 1 for vertices texture + """ + st = time.time() + here = {'device': self.device, 'dtype': torch.float16} + input_view_image = data["input_view"].to(**here) # [b, m, c, h, w] + input_view_cam = data["input_view_cam"].to(**here) # [b, m, 20] + + batch_size, input_view_num, *_ = input_view_image.shape + assert batch_size == 1, "batch size should be 1" + + input_view_image = rearrange(input_view_image, 'b m c h w -> (b m) c h w') + input_view_cam = rearrange(input_view_cam, 'b m d -> (b m) d') + input_view_feat = self.img_encoder(input_view_image, input_view_cam) + input_view_feat = rearrange(input_view_feat, '(b m) l d -> b (l m) d', m=input_view_num) + + # -- decoder + torch.cuda.empty_cache() + triplane_gen = self.img_to_triplane_decoder(input_view_feat) # [b, 3, tri_dim, h, w] + del input_view_feat + torch.cuda.empty_cache() + + # --- triplane nerf render + + cur_triplane = triplane_gen[0:1] 
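+ # [editor note] The block below samples the triplane on a dense mesh_size^3 grid inside the
+ # [-0.6, 0.6]^3 box, runs marching cubes on the zero level set of the negated SDF, and maps
+ # grid indices back to world coordinates:
+ #   vtx_world = vtx_idx / (mesh_size - 1) * (bbox_max - bbox_min) + bbox_min
+ # e.g. with the default mesh_size = 384, index 0 maps to -0.6 and index 383 maps to +0.6.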
+ + aabb = torch.tensor([[-0.6, -0.6, -0.6], [0.6, 0.6, 0.6]]).unsqueeze(0).to(**here) + grid_out = self.render.forward_grid(planes=cur_triplane, grid_size=mesh_size, aabb=aabb) + + print(f"=====> LRM forward time: {time.time() - st}") + st = time.time() + + vtx, faces = mcubes.marching_cubes(0. - grid_out['sdf'].squeeze(0).squeeze(-1).cpu().float().numpy(), 0) + + bbox = aabb[0].cpu().numpy() + vtx = vtx / (mesh_size - 1) + vtx = vtx * (bbox[1] - bbox[0]) + bbox[0] + + # refine mesh + vtx_refine, faces_refine, mesh = refine_mesh(vtx, faces) + + # reduce faces + if faces_refine.shape[0] > target_face_count: + print(f"reduce face: {faces_refine.shape[0]} -> {target_face_count}") + mesh = o3d.geometry.TriangleMesh( + vertices = o3d.utility.Vector3dVector(vtx_refine), + triangles = o3d.utility.Vector3iVector(faces_refine) + ) + + # Function to simplify mesh using Quadric Error Metric Decimation by Garland and Heckbert + mesh = mesh.simplify_quadric_decimation(target_face_count, boundary_weight=1.0) + + mesh = Mesh( + v_pos = torch.from_numpy(np.asarray(mesh.vertices)).to(self.device), + t_pos_idx = torch.from_numpy(np.asarray(mesh.triangles)).to(self.device), + v_rgb = torch.from_numpy(np.asarray(mesh.vertex_colors)).to(self.device) + ) + vtx_refine = mesh.v_pos.cpu().numpy() + faces_refine = mesh.t_pos_idx.cpu().numpy() + + vtx_colors = self.render.forward_points(cur_triplane, torch.tensor(vtx_refine).unsqueeze(0).to(**here)) + vtx_colors = vtx_colors['rgb'].float().squeeze(0).cpu().numpy() + + color_ratio = 0.8 # increase brightness + with open(f'{out_dir}/mesh_with_colors.obj', 'w') as fid: + verts = vtx_refine[:, [1,2,0]] + for pidx, pp in enumerate(verts): + color = vtx_colors[pidx] + color = [color[0]**color_ratio, color[1]**color_ratio, color[2]**color_ratio] + fid.write('v %f %f %f %f %f %f\n' % (pp[0], pp[1], pp[2], color[0], color[1], color[2])) + for i, f in enumerate(faces_refine): + f1 = f + 1 + fid.write('f %d %d %d\n' % (f1[0], f1[1], f1[2])) + + mesh = trimesh.load_mesh(f'{out_dir}/mesh_with_colors.obj') + print(f"=====> generate mesh with vertex shading time: {time.time() - st}") + st = time.time() + + if not do_texture_mapping: + shutil.copy(f'{out_dir}/mesh_with_colors.obj', f'{out_dir}/mesh.obj') + mesh.export(f'{out_dir}/mesh.glb', file_type='glb') + return None + + ########## export texture ######## + st = time.time() + + # uv unwrap + vtx_tex, t_tex_idx = unwrap_uv(vtx_refine, faces_refine) + vtx_refine = torch.from_numpy(vtx_refine).to(self.device) + faces_refine = torch.from_numpy(faces_refine).to(self.device) + t_tex_idx = torch.from_numpy(t_tex_idx).to(self.device) + uv_clip = torch.from_numpy(vtx_tex * 2.0 - 1.0).to(self.device) + + # rasterize + ctx = NVDiffRasterizerContext(context_type, cur_triplane.device) if ctx is None else ctx + rast = ctx.rasterize_one( + torch.cat([ + uv_clip, + torch.zeros_like(uv_clip[..., 0:1]), + torch.ones_like(uv_clip[..., 0:1]) + ], dim=-1), + t_tex_idx, + (texture_res, texture_res) + )[0] + hole_mask = ~(rast[:, :, 3] > 0) + + # Interpolate world space position + gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0] + with torch.no_grad(): + gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1)) + tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb'] + tex_map = tex_map.float().squeeze(0) # (0, 1) + tex_map = tex_map.view((texture_res, texture_res, 3)) + img = uv_padding(tex_map, hole_mask) + img = ((img/255.0) ** color_ratio) * 255 # increase brightness 
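+ # [editor note] color_ratio = 0.8 acts as a gamma < 1 and brightens mid-tones
+ # (e.g. 0.5 -> 0.5 ** 0.8 ≈ 0.57); the same exponent was applied to the per-vertex colors
+ # when writing mesh_with_colors.obj above. uv_padding() has already used cv2.inpaint to fill
+ # the texels outside the rasterized UV charts (hole_mask), which avoids visible seams when
+ # the texture is sampled near chart borders.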
+ img = img.clip(0, 255).astype(np.uint8) + + verts = vtx_refine.cpu().numpy()[:, [1,2,0]] + faces = faces_refine.cpu().numpy() + + with open(f'{out_dir}/texture.mtl', 'w') as fid: + fid.write('newmtl material_0\n') + fid.write("Ka 1.000 1.000 1.000\n") + fid.write("Kd 1.000 1.000 1.000\n") + fid.write("Ks 0.000 0.000 0.000\n") + fid.write("d 1.0\n") + fid.write("illum 2\n") + fid.write(f'map_Kd texture.png\n') + + with open(f'{out_dir}/mesh.obj', 'w') as fid: + fid.write(f'mtllib texture.mtl\n') + for pidx, pp in enumerate(verts): + fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) + for pidx, pp in enumerate(vtx_tex): + fid.write('vt %f %f\n' % (pp[0], 1 - pp[1])) + fid.write('usemtl material_0\n') + for i, f in enumerate(faces): + f1 = f + 1 + f2 = t_tex_idx[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2],)) + + cv2.imwrite(f'{out_dir}/texture.png', img[..., [2, 1, 0]]) + mesh = trimesh.load_mesh(f'{out_dir}/mesh.obj') + mesh.export(f'{out_dir}/mesh.glb', file_type='glb') + \ No newline at end of file diff --git a/svrm/ldm/modules/attention.py b/svrm/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..380d3d7b8f1e87aee101c4ac6a9797e3de0682c9 --- /dev/null +++ b/svrm/ldm/modules/attention.py @@ -0,0 +1,457 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +import numpy as np + +FLASH_IS_AVAILABLE = XFORMERS_IS_AVAILBLE = False +try: + from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + FLASH_IS_AVAILABLE = True +except: + try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True + except: + pass + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
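+ # [editor note] Standard activation-checkpointing backward: the forward pass above ran
+ # under no_grad, so the detached inputs are re-run here with grad enabled and
+ # torch.autograd.grad recomputes gradients w.r.t. both inputs and parameters; the two
+ # leading `None`s in the return value correspond to the (run_function, length) arguments
+ # of forward().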
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + self.scale = dim_head ** -0.5 + self.heads = heads + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = 
nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', attn, v) # [b*h, n, d] + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class FlashAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + self.scale = dim_head ** -0.5 + self.heads = heads + self.dropout = dropout + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + context = default(context, x) + h = self.heads + dtype = torch.bfloat16 # torch.half + q = self.to_q(x).to(dtype) + k = self.to_k(context).to(dtype) + v = self.to_v(context).to(dtype) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64] + out = flash_attn_func(q, k, v, dropout_p=self.dropout, softmax_scale=None, causal=False, window_size=(-1, -1)) # out is same shape to q + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + return self.to_out(out.float()) + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. 
Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = Fp32LayerNorm(dim) + self.norm2 = Fp32LayerNorm(dim) + self.norm3 = Fp32LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + +ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, + "softmax-flash": FlashAttention +} + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +class AdaNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 2 * dim, bias=True) + ) + self.norm = Fp32LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, c): # x is fp32, c is fp16 + shift, scale = self.adaLN_modulation(c.float()).chunk(2, dim=1) # bf16 + x = modulate(self.norm(x), shift, scale) # fp32 + return x + + +class BasicTransformerBlockLRM(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, \ + checkpoint=True): + super().__init__() + + attn_mode = "softmax-xformers" if 
XFORMERS_IS_AVAILBLE else "softmax" + attn_mode = "softmax-flash" if FLASH_IS_AVAILABLE else attn_mode + assert attn_mode in ATTENTION_MODES + attn_cls = ATTENTION_MODES[attn_mode] + + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \ + context_dim=context_dim) # cross-attn + self.attn2 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \ + context_dim=None) # self-attn + + self.norm1 = Fp32LayerNorm(dim) + self.norm2 = Fp32LayerNorm(dim) + self.norm3 = Fp32LayerNorm(dim) + + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.checkpoint = checkpoint + + def forward(self, x, context=None, cam_emb=None): # (torch.float32, torch.float32, torch.bfloat16) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + + def _forward(self, x, context=None, cam_emb=None): + + x = self.attn1(self.norm1(x), context=context) + x # cross-attn + x = self.attn2(self.norm2(x), context=None) + x # self-attn + x = self.ff(self.norm3(x)) + x + + return x + +class ImgToTriplaneTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, query_dim, n_heads, d_head, depth=1, dropout=0., context_dim=None, triplane_size=64): + super().__init__() + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlockLRM(query_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)]) + + self.norm = Fp32LayerNorm(query_dim, eps=1e-6) + + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LayerNorm): + if module.bias is not None: + nn.init.constant_(module.bias, 0) + if module.weight is not None: + nn.init.constant_(module.weight, 1.0) + self.apply(_basic_init) + + def forward(self, x, context=None, cam_emb=None): + # note: if no context is given, cross-attention defaults to self-attention + for block in self.transformer_blocks: + x = block(x, context=context) + x = self.norm(x) + return x + + + + + diff --git a/svrm/ldm/modules/encoders/__init__.py b/svrm/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/svrm/ldm/modules/encoders/dinov2/__init__.py b/svrm/ldm/modules/encoders/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/svrm/ldm/modules/encoders/dinov2/hub/__init__.py b/svrm/ldm/modules/encoders/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/svrm/ldm/modules/encoders/dinov2/hub/backbones.py b/svrm/ldm/modules/encoders/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..51bb1ffe50525db3ed2ac78c3bb40776dd530645 --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
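+ # [editor note] The SVRM config above sets `version: dinov2_vitb14` on
+ # FrozenDinoV2ImageEmbedder; that string presumably resolves to the dinov2_vitb14()
+ # factory defined below (ViT-B/14, patch size 14, optionally loading LVD-142M weights from
+ # _DINOV2_BASE_URL). A hypothetical offline use would be:
+ #     backbone = dinov2_vitb14(pretrained=False)  # build the ViT without downloading weights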
+ +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/svrm/ldm/modules/encoders/dinov2/hub/utils.py b/svrm/ldm/modules/encoders/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7afea3273713518e891d1e6b8e86d58b4700fddc --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/svrm/ldm/modules/encoders/dinov2/layers/__init__.py b/svrm/ldm/modules/encoders/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70dff7a567de20c911559522e678dd6b605fdb9d --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlockMod +from .attention import MemEffAttention diff --git a/svrm/ldm/modules/encoders/dinov2/layers/attention.py b/svrm/ldm/modules/encoders/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f62db59d697e151d4aaec272b897fd07c8a8ab --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/svrm/ldm/modules/encoders/dinov2/layers/block.py b/svrm/ldm/modules/encoders/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..d0de346f02648eda86b6c9dc8a599e5da0f88636 --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/block.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
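+ # [editor note] BlockMod below differs from the stock DINOv2 block mainly in its norm:
+ # norm_layer defaults to the AdaNorm defined in svrm/ldm/modules/attention.py, so norm1/norm2
+ # are called as self.norm1(x, cam_emb) and apply modulate(LayerNorm(x), shift, scale) with
+ # shift/scale predicted from the camera embedding by a SiLU + Linear head. This appears to be
+ # how the per-view camera embedding (input_view_cam) conditions the image encoder.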
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import os +import logging +import warnings +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +from ....attention import AdaNorm + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class BlockMod(nn.Module): + ''' + using Modified Block, see below + ''' + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = AdaNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, cam_emb: Tensor) -> Tensor: + def attn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor: + return self.ls1(self.attn(self.norm1(x, cam_emb))) + + def ffn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x, cam_emb))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, cam_emb)) + x = x + self.drop_path1(ffn_residual_func(x, cam_emb)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, cam_emb) + x = x + ffn_residual_func(x, cam_emb) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # drop_add_residual_stochastic_depth_list + + 
# 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + # get_branges_scales + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + # add residuals + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlockMod(BlockMod): + def forward_nested(self, x_list: List[Tensor], cam_emb_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + 
assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x, cam_emb)) + + x_list = drop_add_residual_stochastic_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x, cam_emb))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, cam_emb_or_cam_emb_list): + if isinstance(x_or_x_list, Tensor) and isinstance(cam_emb_or_cam_emb_list, Tensor) : + return super().forward(x_or_x_list, cam_emb_or_cam_emb_list) + elif isinstance(x_or_x_list, list) and isinstance(cam_emb_or_cam_emb_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list, cam_emb_or_cam_emb_list) + else: + raise AssertionError diff --git a/svrm/ldm/modules/encoders/dinov2/layers/dino_head.py b/svrm/ldm/modules/encoders/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ccca59999e1d686e1341281c61e8961f1b0e6545 --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
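The drop_add_residual_stochastic_depth helper in block.py above runs the residual branch on a random subset of the batch and rescales the added residual so its expectation matches a full-batch update. A toy sketch of that index_add pattern (the residual tensor here is a random placeholder for the attention/FFN output):

    import torch

    b, n, d = 8, 4, 16
    x = torch.randn(b, n, d)
    sample_drop_ratio = 0.25
    subset = max(int(b * (1 - sample_drop_ratio)), 1)    # 6 of the 8 samples run the branch
    brange = torch.randperm(b)[:subset]
    residual = torch.randn(subset, n, d)                 # stand-in for attn/ffn(x[brange])
    scale = b / subset                                   # compensates for the dropped samples
    out = torch.index_add(x.flatten(1), 0, brange, residual.flatten(1), alpha=scale)
    print(out.view_as(x).shape)                          # torch.Size([8, 4, 16])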
+ +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/svrm/ldm/modules/encoders/dinov2/layers/drop_path.py b/svrm/ldm/modules/encoders/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb1487b0eed4cb14dc0d5d1ee57a2acc78de34a --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py b/svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..5468ee2dce0a9446c028791de5cff1ff068a4fe5 --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
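A short sketch of the DropPath behaviour defined above: during training it zeroes whole samples with probability drop_prob and rescales the survivors by 1/(1 - drop_prob), so the output matches the input in expectation; at eval time it is the identity.

    import torch
    from svrm.ldm.modules.encoders.dinov2.layers.drop_path import DropPath

    dp = DropPath(drop_prob=0.2)
    x = torch.ones(10000, 1, 1)

    dp.train()
    y = dp(x)
    print(y.unique())             # tensor([0.0000, 1.2500]): dropped samples vs. 1 / (1 - 0.2)
    print(y.mean().item())        # ~1.0 on average

    dp.eval()
    print(torch.equal(dp(x), x))  # True: identity at inference time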
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/svrm/ldm/modules/encoders/dinov2/layers/mlp.py b/svrm/ldm/modules/encoders/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..0965768a9aef04ac6b81322f4dd60cf035159e91 --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/svrm/ldm/modules/encoders/dinov2/layers/patch_embed.py b/svrm/ldm/modules/encoders/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3aaf46c523ab1ae27430419187bbad11e302ab --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/svrm/ldm/modules/encoders/dinov2/layers/swiglu_ffn.py b/svrm/ldm/modules/encoders/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..3765d5def655f0a23f3803f4c7f79c33d3ecfd55 --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/svrm/ldm/modules/encoders/dinov2/models/__init__.py b/svrm/ldm/modules/encoders/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01c40a2a10ce2b1e39f666328132c2a80111072d --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . 
import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/svrm/ldm/modules/encoders/dinov2/models/vision_transformer.py b/svrm/ldm/modules/encoders/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..666987e13e5c93ecc62e1d3e3b97bca0987d5edd --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2/models/vision_transformer.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlockMod as BlockMod +from ....attention import AdaNorm + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=BlockMod, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + pos_emb_dim=768, + cam_cond_dim=20 + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): 
embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + + norm_layer = AdaNorm + self.cam_embed = nn.Sequential( + nn.Linear(cam_cond_dim, pos_emb_dim, bias=True), + nn.SiLU(), + nn.Linear(pos_emb_dim, pos_emb_dim, bias=True)) + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = 
nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features_list_with_camera(self, x_list, cam_cond_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + cam_emb = [self.cam_embed(cam_cond) for cam_cond in cam_cond_list] + for blk in self.blocks: + x = blk(x, cam_emb) + + all_x = x + all_cam_emb = cam_emb + output = [] + for x, cam_emb, masks in zip(all_x, all_cam_emb, masks_list): + x_norm = self.norm(x, cam_emb) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": 
x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def forward_features_with_camera(self, x, cam_cond, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, cam_cond, masks) + cam_emb = self.cam_embed(cam_cond) + x = self.prepare_tokens_with_masks(x, masks) + for blk in self.blocks: + x = blk(x, cam_emb) + x_norm = self.norm(x, cam_emb) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_inter_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_inter_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + + ret = self.forward_features_with_camera(*args, **kwargs) + + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, AdaNorm): + nn.init.constant_(module.adaLN_modulation[-1].weight, 0) + nn.init.constant_(module.adaLN_modulation[-1].bias, 0) + elif isinstance(module, nn.LayerNorm): + if module.bias is not None: + nn.init.constant_(module.bias, 0) + if module.weight is not None: + nn.init.constant_(module.weight, 1.0) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + 
num_heads=6, + mlp_ratio=4, + block_fn=partial(BlockMod, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(BlockMod, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/svrm/ldm/modules/encoders/dinov2_mod.py b/svrm/ldm/modules/encoders/dinov2_mod.py new file mode 100644 index 0000000000000000000000000000000000000000..19c4bb6c9e09734b18d621b9955d46ca7cd6275f --- /dev/null +++ b/svrm/ldm/modules/encoders/dinov2_mod.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .dinov2.hub.backbones import dinov2_vitb14 + +class FrozenDinoV2ImageEmbedder(nn.Module): + """ + Uses the dinov2 image encoder with camera modulation. + Not actually frozen... If you want that set cond_stage_trainable=False in cfg + """ + def __init__( + self, + version='dinov2_vitb14', + ckpt_path=None, + lrm_mode='plain_lrm', + ): + super().__init__() + self.lrm_mode = lrm_mode + assert version in ['dinov2_vitb14', 'dinov2_vits14', 'dinov2_vitl14', 'dinov2_vitg14'] + + + self.model = dinov2_vitb14(pretrained=False) + + if ckpt_path is not None: + self.load_pretrained(ckpt_path) + else: + print('None pretrained model for dinov2 encoder ...') + + + def load_pretrained(self, ckpt_path): + print('Loading dinov2 encoder ...') + orig_state_dict = torch.load(ckpt_path, map_location='cpu') + try: + ret = self.model.load_state_dict(orig_state_dict, strict=False) + print(ret) + print('Successfully loaded orig state dict') + except: + new_state_dict = OrderedDict() + for k, v in orig_state_dict['state_dict'].items(): + if 'img_encoder' in k: + new_state_dict[k.replace('img_encoder.model.', '')] = v + ret = self.model.load_state_dict(new_state_dict, strict=False) + print(ret) + print('Successfully loaded new state dict') + + + def forward(self, x, *args, **kwargs): + ret = self.model.forward_features_with_camera(x, *args, **kwargs) + output = torch.cat([ret['x_norm_clstoken'].unsqueeze(1), ret['x_norm_patchtokens']], dim=1) + return output diff --git a/svrm/ldm/modules/rendering_neus/__init__.py b/svrm/ldm/modules/rendering_neus/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7f09c03caa76149581d5b15b541aecb3892d6a --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
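Stepping back to the FrozenDinoV2ImageEmbedder added above, a hedged usage sketch: the 448-pixel resolution and random camera vector are illustrative, the 20-dim camera condition matches cam_cond_dim in the vision transformer, and the output concatenates the class token with the patch tokens (so 1 + (448 // 14)^2 tokens for a ViT-B/14 backbone).

    import torch
    from svrm.ldm.modules.encoders.dinov2_mod import FrozenDinoV2ImageEmbedder

    encoder = FrozenDinoV2ImageEmbedder(version='dinov2_vitb14', ckpt_path=None)
    images = torch.randn(2, 3, 448, 448)      # H and W must be multiples of the 14-pixel patch
    cam_cond = torch.randn(2, 20)             # per-image camera conditioning vector
    tokens = encoder(images, cam_cond)
    print(tokens.shape)                       # expected torch.Size([2, 1025, 768])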
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/svrm/ldm/modules/rendering_neus/mesh.py b/svrm/ldm/modules/rendering_neus/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..6e224b84d91cf5a999f83f24cd94873151494094 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/mesh.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F + +from ...utils.typing import * + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + +class Mesh: + def __init__( + self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], v_rgb: Integer[Tensor, "Nf 3"], **kwargs + ) -> None: + self.v_pos: Float[Tensor, "Nv 3"] = v_pos + self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx + self.v_rgb: Optional[Float[Tensor, "Nv 3"]] = v_rgb + + self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None + self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None + # self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None + self._edges: Optional[Integer[Tensor, "Ne 2"]] = None + self.extras: Dict[str, Any] = {} + for k, v in kwargs.items(): + self.add_extra(k, v) + + def add_extra(self, k, v) -> None: + self.extras[k] = v + + def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh: + if self.requires_grad: + print("Mesh is differentiable, not removing outliers") + return self + + # use trimesh to first split the mesh into connected components + # then remove the components with less than n_face_threshold faces + import trimesh + + # construct a trimesh object + mesh = trimesh.Trimesh( + vertices=self.v_pos.detach().cpu().numpy(), + faces=self.t_pos_idx.detach().cpu().numpy(), + ) + + # split the mesh into connected components + components = mesh.split(only_watertight=False) + # log the number of faces in each component + print( + "Mesh has {} components, with faces: {}".format( + len(components), [c.faces.shape[0] for c in components] + ) + ) + + n_faces_threshold: int + if isinstance(outlier_n_faces_threshold, float): + # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold + n_faces_threshold = int( + max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold + ) + else: + # set the threshold directly to outlier_n_faces_threshold + n_faces_threshold = outlier_n_faces_threshold + + # log the threshold + print( + "Removing components with less than {} faces".format(n_faces_threshold) + ) + + # remove the components with less than n_face_threshold faces + components = [c for c in components if c.faces.shape[0] >= n_faces_threshold] + + # log the number of faces in each component after removing outliers + print( + "Mesh has {} components after removing outliers, with faces: {}".format( + len(components), [c.faces.shape[0] for c in components] + ) + ) + # merge the components + mesh = trimesh.util.concatenate(components) + + # convert back to our mesh format + v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos) + t_pos_idx = 
torch.from_numpy(mesh.faces).to(self.t_pos_idx) + + clean_mesh = Mesh(v_pos, t_pos_idx) + # keep the extras unchanged + + if len(self.extras) > 0: + clean_mesh.extras = self.extras + print( + f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}" + ) + return clean_mesh + + @property + def requires_grad(self): + return self.v_pos.requires_grad + + @property + def v_nrm(self): + if self._v_nrm is None: + self._v_nrm = self._compute_vertex_normal() + return self._v_nrm + + @property + def v_tng(self): + if self._v_tng is None: + self._v_tng = self._compute_vertex_tangent() + return self._v_tng + + @property + def v_tex(self): + if self._v_tex is None: + self._v_tex, self._t_tex_idx = self._unwrap_uv() + return self._v_tex + + @property + def t_tex_idx(self): + if self._t_tex_idx is None: + self._v_tex, self._t_tex_idx = self._unwrap_uv() + return self._t_tex_idx + + # @property + # def v_rgb(self): + # return self._v_rgb + + @property + def edges(self): + if self._edges is None: + self._edges = self._compute_edges() + return self._edges + + def _compute_vertex_normal(self): + i0 = self.t_pos_idx[:, 0] + i1 = self.t_pos_idx[:, 1] + i2 = self.t_pos_idx[:, 2] + + v0 = self.v_pos[i0, :] + v1 = self.v_pos[i1, :] + v2 = self.v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(self.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + def _compute_vertex_tangent(self): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0, 3): + pos[i] = self.v_pos[self.t_pos_idx[:, i]] + tex[i] = self.v_tex[self.t_tex_idx[:, i]] + # t_nrm_idx is always the same as t_pos_idx + vn_idx[i] = self.t_pos_idx[:, i] + + tangents = torch.zeros_like(self.v_nrm) + tansum = torch.zeros_like(self.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2] + denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1] + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where( + denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6) + ) + + # Update all 3 vertices + for i in range(0, 3): + idx = vn_idx[i][:, None].repeat(1, 3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + tansum.scatter_add_( + 0, idx, torch.ones_like(tang) + ) # tansum[n_i] = tansum[n_i] + 1 + tangents = tangents / tansum + + # Normalize and make sure tangent is perpendicular to normal + tangents = F.normalize(tangents, dim=1) + tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return tangents + + def _unwrap_uv( + self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {} + ): + print("Using xatlas to perform UV unwrapping, may take a while ...") + + import xatlas + + atlas = xatlas.Atlas() + 
atlas.add_mesh( + self.v_pos.detach().cpu().numpy(), + self.t_pos_idx.cpu().numpy(), + ) + co = xatlas.ChartOptions() + po = xatlas.PackOptions() + for k, v in xatlas_chart_options.items(): + setattr(co, k, v) + for k, v in xatlas_pack_options.items(): + setattr(po, k, v) + atlas.generate(co, po) + vmapping, indices, uvs = atlas.get_mesh(0) + vmapping = ( + torch.from_numpy( + vmapping.astype(np.uint64, casting="same_kind").view(np.int64) + ) + .to(self.v_pos.device) + .long() + ) + uvs = torch.from_numpy(uvs).to(self.v_pos.device).float() + indices = ( + torch.from_numpy( + indices.astype(np.uint64, casting="same_kind").view(np.int64) + ) + .to(self.v_pos.device) + .long() + ) + return uvs, indices + + def unwrap_uv( + self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {} + ): + self._v_tex, self._t_tex_idx = self._unwrap_uv( + xatlas_chart_options, xatlas_pack_options + ) + + def set_vertex_color(self, v_rgb): + assert v_rgb.shape[0] == self.v_pos.shape[0] + self._v_rgb = v_rgb + + def _compute_edges(self): + # Compute edges + edges = torch.cat( + [ + self.t_pos_idx[:, [0, 1]], + self.t_pos_idx[:, [1, 2]], + self.t_pos_idx[:, [2, 0]], + ], + dim=0, + ) + edges = edges.sort()[0] + edges = torch.unique(edges, dim=0) + return edges + + def normal_consistency(self) -> Float[Tensor, ""]: + edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges] + nc = ( + 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1) + ).mean() + return nc + + def _laplacian_uniform(self): + # from stable-dreamfusion + # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224 + verts, faces = self.v_pos, self.t_pos_idx + + V = verts.shape[0] + F = faces.shape[0] + + # Neighbor indices + ii = faces[:, [1, 2, 0]].flatten() + jj = faces[:, [2, 0, 1]].flatten() + adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique( + dim=1 + ) + adj_values = torch.ones(adj.shape[1]).to(verts) + + # Diagonal indices + diag_idx = adj[0] + + # Build the sparse matrix + idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1) + values = torch.cat((-adj_values, adj_values)) + + # The coalesce operation sums the duplicate indices, resulting in the + # correct diagonal + return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce() + + def laplacian(self) -> Float[Tensor, ""]: + with torch.no_grad(): + L = self._laplacian_uniform() + loss = L.mm(self.v_pos) + loss = loss.norm(dim=1) + loss = loss.mean() + return loss \ No newline at end of file diff --git a/svrm/ldm/modules/rendering_neus/rasterize.py b/svrm/ldm/modules/rendering_neus/rasterize.py new file mode 100644 index 0000000000000000000000000000000000000000..5e270fbf75be18e84540e31bcf566a83efb4827a --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/rasterize.py @@ -0,0 +1,78 @@ +import nvdiffrast.torch as dr +import torch + +from ...utils.typing import * + + +class NVDiffRasterizerContext: + def __init__(self, context_type: str, device: torch.device) -> None: + self.device = device + self.ctx = self.initialize_context(context_type, device) + + def initialize_context( + self, context_type: str, device: torch.device + ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]: + if context_type == "gl": + return dr.RasterizeGLContext(device=device) + elif context_type == "cuda": + return dr.RasterizeCudaContext(device=device) + else: + raise ValueError(f"Unknown rasterizer context type: {context_type}") + + def vertex_transform( + self, verts: Float[Tensor, 
"Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"] + ) -> Float[Tensor, "B Nv 4"]: + verts_homo = torch.cat( + [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1 + ) + return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1)) + + def rasterize( + self, + pos: Float[Tensor, "B Nv 4"], + tri: Integer[Tensor, "Nf 3"], + resolution: Union[int, Tuple[int, int]], + ): + # rasterize in instance mode (single topology) + return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True) + + def rasterize_one( + self, + pos: Float[Tensor, "Nv 4"], + tri: Integer[Tensor, "Nf 3"], + resolution: Union[int, Tuple[int, int]], + ): + # rasterize one single mesh under a single viewpoint + rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) + return rast[0], rast_db[0] + + def antialias( + self, + color: Float[Tensor, "B H W C"], + rast: Float[Tensor, "B H W 4"], + pos: Float[Tensor, "B Nv 4"], + tri: Integer[Tensor, "Nf 3"], + ) -> Float[Tensor, "B H W C"]: + return dr.antialias(color.float(), rast, pos.float(), tri.int()) + + def interpolate( + self, + attr: Float[Tensor, "B Nv C"], + rast: Float[Tensor, "B H W 4"], + tri: Integer[Tensor, "Nf 3"], + rast_db=None, + diff_attrs=None, + ) -> Float[Tensor, "B H W C"]: + return dr.interpolate( + attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs + ) + + def interpolate_one( + self, + attr: Float[Tensor, "Nv C"], + rast: Float[Tensor, "B H W 4"], + tri: Integer[Tensor, "Nf 3"], + rast_db=None, + diff_attrs=None, + ) -> Float[Tensor, "B H W C"]: + return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) diff --git a/svrm/ldm/modules/rendering_neus/synthesizer.py b/svrm/ldm/modules/rendering_neus/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7acb391a27dead7490b87f742abc05fd1087d218 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/synthesizer.py @@ -0,0 +1,277 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Zexin He +# The modifications are subject to the same license as the original. + + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils.renderer import ImportanceRenderer, sample_from_planes +from .utils.ray_sampler import RaySampler +from ...utils.ops import get_rank + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. 
+ + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, + num_layers: int = 2, + activation: nn.Module = nn.ReLU, + sdf_bias='sphere', + sdf_bias_params=0.5, + output_normal=True, + normal_type='finite_difference'): + super().__init__() + self.sdf_bias = sdf_bias + self.sdf_bias_params = sdf_bias_params + self.output_normal = output_normal + self.normal_type = normal_type + self.net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1 + 3), + ) + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def forward(self, ray_directions, sample_coordinates, plane_axes, planes, options): + # Aggregate features by mean + # sampled_features = sampled_features.mean(1) + # Aggregate features by concatenation + # torch.set_grad_enabled(True) + # sample_coordinates.requires_grad_(True) + + sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + x = sampled_features + + N, M, C = x.shape + # x = x.contiguous().view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + + sdf = x[..., 0:1] + # import ipdb; ipdb.set_trace() + # print(f'sample_coordinates shape: {sample_coordinates.shape}') + # sdf = self.get_shifted_sdf(sample_coordinates, sdf) + + # calculate normal + eps = 0.01 + offsets = torch.as_tensor( + [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]] + ).to(sample_coordinates) + points_offset = ( + sample_coordinates[..., None, :] + offsets # Float[Tensor, "... 3 3"] + ).clamp(options['sampler_bbox_min'], options['sampler_bbox_max']) + + sdf_offset_list = [self.forward_sdf( + plane_axes, + planes, + points_offset[:,:,i,:], + options + ).unsqueeze(-2) for i in range(points_offset.shape[-2])] # Float[Tensor, "... 3 1"] + # import ipdb; ipdb.set_trace() + + sdf_offset = torch.cat(sdf_offset_list, -2) + sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps + + normal = F.normalize(sdf_grad, dim=-1).to(sdf.dtype) + return {'rgb': rgb, 'sdf': sdf, 'normal': normal, 'sdf_grad': sdf_grad} + + def forward_sdf(self, plane_axes, planes, points_offset, options): + + sampled_features = sample_from_planes(plane_axes, planes, points_offset, padding_mode='zeros', box_warp=options['box_warp']) + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + x = sampled_features + + N, M, C = x.shape + # x = x.contiguous().view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + sdf = x[..., 0:1] + # sdf = self.get_shifted_sdf(points_offset, sdf) + return sdf + + def get_shifted_sdf( + self, points, sdf + ): + if self.sdf_bias == "sphere": + assert isinstance(self.sdf_bias_params, float) + radius = self.sdf_bias_params + sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius + else: + raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}") + return sdf + sdf_bias.to(sdf.dtype) + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. 
+ + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 1.2, + # 'box_warp': 1., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + # 'sampler_bbox_min': -1, + # 'sampler_bbox_max': 1., + 'sampler_bbox_min': -0.6, + 'sampler_bbox_max': 0.6, + } + print('DEFAULT_RENDERING_KWARGS') + print(DEFAULT_RENDERING_KWARGS) + + + def __init__(self, triplane_dim: int, samples_per_ray: int, osg_decoder='default'): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray, + 'depth_resolution_importance': 0 + # 'depth_resolution': samples_per_ray // 2, + # 'depth_resolution_importance': samples_per_ray // 2, + } + + # renderings + self.renderer = ImportanceRenderer() + self.ray_sampler = RaySampler() + # modules + if osg_decoder == 'default': + self.decoder = OSGDecoder(n_features=triplane_dim) + else: + raise NotImplementedError + + def forward(self, planes, ray_origins, ray_directions, render_size, bgcolor=None): + # planes: (N, 3, D', H', W') + # render_size: int + assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" + + + # Perform volume rendering + rgb_samples, depth_samples, weights_samples, sdf_grad, normal_samples = self.renderer( + planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, bgcolor + ) + N = planes.shape[0] + + # zhaohx : add for normals + normal_samples = F.normalize(normal_samples, dim=-1) + normal_samples = (normal_samples + 1.0) / 2.0 # for visualization + normal_samples = torch.lerp(torch.zeros_like(normal_samples), normal_samples, weights_samples) + + # Reshape into 'raw' neural-rendered image + Himg = Wimg = render_size + rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, rgb_samples.shape[-1], Himg, Wimg).contiguous() + depth_images = depth_samples.permute(0, 2, 1).reshape(N, 1, Himg, Wimg) + weight_images = weights_samples.permute(0, 2, 1).reshape(N, 1, Himg, Wimg) + + # zhaohx : add for normals + normal_images = normal_samples.permute(0, 2, 1).reshape(N, normal_samples.shape[-1], Himg, Wimg).contiguous() + + # return { + # 'images_rgb': rgb_images, + # 'images_depth': depth_images, + # 'images_weight': weight_images, + # } + + return { + 'comp_rgb': rgb_images, + 'comp_depth': depth_images, + 'opacity': weight_images, + 'sdf_grad': sdf_grad, + 'comp_normal': normal_images + } + # 输出normal的话在这个return里加 + + def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): + # planes: (N, 3, D', H', W') + # grid_size: int + # aabb: (N, 2, 3) + if aabb is None: + aabb = torch.tensor([ + [self.rendering_kwargs['sampler_bbox_min']] * 3, + [self.rendering_kwargs['sampler_bbox_max']] * 3, + ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) + assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" + N = planes.shape[0] + + # create grid points for triplane query + grid_points = [] + for i in range(N): + grid_points.append(torch.stack(torch.meshgrid( + torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), + indexing='ij', + ), dim=-1).reshape(-1, 3)) + cube_grid = torch.stack(grid_points, 
dim=0).to(planes.device) + + features = self.forward_points(planes, cube_grid) + + # reshape into grid + features = { + k: v.reshape(N, grid_size, grid_size, grid_size, -1) + for k, v in features.items() + } + return features + + def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): + # planes: (N, 3, D', H', W') + # points: (N, P, 3) + N, P = points.shape[:2] + + # query triplane in chunks + outs = [] + for i in range(0, points.shape[1], chunk_size): + chunk_points = points[:, i:i+chunk_size] + + # query triplane + # chunk_out = self.renderer.run_model_activated( + chunk_out = self.renderer.run_model( + planes=planes, + decoder=self.decoder, + sample_coordinates=chunk_points, + sample_directions=torch.zeros_like(chunk_points), + options=self.rendering_kwargs, + ) + outs.append(chunk_out) + + # concatenate the outputs + point_features = { + k: torch.cat([out[k] for out in outs], dim=1) + for k in outs[0].keys() + } + return point_features diff --git a/svrm/ldm/modules/rendering_neus/third_party/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb64cc793df9f8474d19b5b29acdd82ed7870da --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# empty diff --git a/svrm/ldm/modules/rendering_neus/third_party/custom_ops.py b/svrm/ldm/modules/rendering_neus/third_party/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3a70b33337358d67f34332ba1456f3efc1f08d68 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/custom_ops.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import glob +import hashlib +import importlib +import os +import re +import shutil +import uuid + +import torch +import torch.utils.cpp_extension +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. 
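Usage note for the TriplaneSynthesizer.forward_grid / forward_points API added above: the sketch below is illustrative only. The synthesizer arguments, the plane shape, the grid size, and the assumption that the renderer's run_model output exposes the decoder's 'sdf' field are all assumptions, not guarantees made by this diff.

    import torch

    # Hypothetical setup: triplane_dim and plane resolution are placeholders.
    synthesizer = TriplaneSynthesizer(triplane_dim=80, samples_per_ray=128).eval()
    planes = torch.randn(1, 3, 80, 64, 64)                        # (N, 3, D', H', W')

    with torch.no_grad():
        fields = synthesizer.forward_grid(planes, grid_size=128)  # dict of (N, G, G, G, C) tensors
    sdf_volume = fields['sdf'][0, ..., 0]                         # (G, G, G) signed distances
    inside = (sdf_volume <= 0)                                    # occupancy at the zero level set

The chunked forward_points path is what keeps this dense query within memory limits.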
+ +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- + +def _get_mangled_gpu_name(): + name = torch.cuda.get_device_name().lower() + out = [] + for c in name: + if re.match('[a-z0-9_-]+', c): + out.append(c) + else: + out.append('-') + return ''.join(out) + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + if headers is None: + headers = [] + if source_dir is not None: + sources = [os.path.join(source_dir, fname) for fname in sources] + headers = [os.path.join(source_dir, fname) for fname in headers] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + verbose_build = (verbosity == 'full') + + # Compile and load. + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either + # break the build or unnecessarily restrict what's available to nvcc. + # Unset it to let nvcc decide based on what's available on the + # machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. 
+ source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/svrm/ldm/modules/rendering_neus/third_party/dnnlib/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4540b93aba3437b83b517eeecdf66c133fbdfd --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from .util import EasyDict, make_cache_dir_path diff --git a/svrm/ldm/modules/rendering_neus/third_party/dnnlib/util.py b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a26590e44be0e9a5011699a88910c0df12ad68d3 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/util.py @@ -0,0 +1,493 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
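As a usage note for the get_plugin helper defined above: the call below mirrors how the ops package later in this diff loads its fused CUDA kernels (see bias_act.py); treat it as an illustrative sketch rather than an additional call site.

    import os

    # Compiles (or reuses the cached build of) the bias_act extension and returns
    # the imported module; repeated calls hit the in-process _cached_plugins dict.
    plugin = get_plugin(
        module_name='bias_act_plugin',
        sources=['bias_act.cpp', 'bias_act.cu'],
        headers=['bias_act.h'],
        source_dir=os.path.dirname(__file__),
        extra_cuda_cflags=['--use_fast_math'],
    )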
+ +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + 
if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def format_time_brief(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? 
+ for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. 
+ Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. 
+ if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/svrm/ldm/modules/rendering_neus/third_party/misc.py b/svrm/ldm/modules/rendering_neus/third_party/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..09428a800825bee0e3828460ccd5a0c9e1e648aa --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/misc.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +from ldm.modules.neus.third_party import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. 
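A brief, illustrative usage sketch for open_url above (the URL is a placeholder; by default the payload is cached under make_cache_dir_path('downloads')):

    # Read the payload through a file-like object (downloads on first use, then caches).
    with open_url('https://example.com/checkpoints/model.pkl', verbose=False) as f:
        data = f.read()

    # Or resolve to the cached local filename instead of a file object.
    local_path = open_url('https://example.com/checkpoints/model.pkl', return_filename=True, verbose=False)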
+ +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). 
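A small illustration of the assert_shape helper defined earlier in this file (tensor and sizes are arbitrary examples): None marks a dimension that may vary, integers are enforced.

    import torch

    feats = torch.zeros(4, 96, 64, 64)
    assert_shape(feats, [None, 96, 64, 64])    # passes: the batch dimension is free
    # assert_shape(feats, [None, 3, 64, 64])   # would raise AssertionError for dimension 1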
+ +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. 
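An illustrative usage sketch for the InfiniteSampler defined above (the dataset, batch size, and rank values are placeholders):

    import torch

    dataset = torch.utils.data.TensorDataset(torch.arange(100))
    sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
    loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8))
    batch = next(loader)   # the index stream never ends, so this never raises StopIteration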
+ +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb64cc793df9f8474d19b5b29acdd82ed7870da --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# empty diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cpp b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f507b2807e4e2e8cd6e79de122767a88ee1b052a --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cpp @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. 
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cu b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..a0118b34928895bc3c4f68d96c5bb2750b42aa72 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cu @@ -0,0 +1,177 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. 
+ scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. 
+ +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.h b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.h new file mode 100644 index 0000000000000000000000000000000000000000..8a2c3a861ce21e6470de2aaaec24bb2713259c63 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.h @@ -0,0 +1,42 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.py b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4e2aedd6deddafe1fd9c80a009f6d1ec99746b --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import numpy as np +import torch +from ldm.modules.neus.third_party import dnnlib + +from .. import custom_ops +from .. 
import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='bias_act_plugin', + sources=['bias_act.cpp', 'bias_act.cu'], + headers=['bias_act.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. 
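+
+    Example:
+        An illustrative sketch (not part of the original docstring); with the
+        default gain for 'relu', this is equivalent to relu(x + b) * sqrt(2):
+
+            x = torch.randn(8, 512, 4, 4, device='cuda')
+            b = torch.zeros(512, device='cuda')
+            y = bias_act(x, b, dim=1, act='relu')                  # fused CUDA path
+            y_ref = bias_act(x, b, dim=1, act='relu', impl='ref')  # reference path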
+ """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. 
+ class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cpp b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05ba9baa08b8bf0aac3643219306b871f9dee650 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cpp @@ -0,0 +1,57 @@ +#include + +#include + +// CUDA forward declarations + +namespace at {namespace native { +std::vector grid_sample2d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners); +std::vector grid_sample3d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners); +}} + +std::vector grid_sample2d_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + return at::native::grid_sample2d_cuda_grad2(grad2_grad_input, grad2_grad_grid, + grad_output, input, grid, padding_mode, align_corners); +} + +std::vector grid_sample3d_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + return at::native::grid_sample3d_cuda_grad2(grad2_grad_input, grad2_grad_grid, + grad_output, input, grid, padding_mode, align_corners); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grad2_2d", &grid_sample2d_grad2, "grid_sample2d second derivative"); + m.def("grad2_3d", &grid_sample3d_grad2, "grid_sample3d second derivative"); +} + diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cu b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cu new file mode 100644 index 0000000000000000000000000000000000000000..02b5a33918eec1c6baac348b765bb2e3467ce319 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cu @@ -0,0 +1,668 @@ +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace at { namespace native { +namespace { + +using namespace at::cuda::detail; + +using at::native::detail::GridSamplerInterpolation; +using at::native::detail::GridSamplerPadding; + +template + C10_LAUNCH_BOUNDS_1(256) + __global__ void grid_sampler_2d_grad2_kernel( + const index_t nthreads, + TensorInfo grad2_grad_input, + TensorInfo grad2_grad_grid, + TensorInfo grad_output, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_grad_output, + TensorInfo grad_input, + TensorInfo grad_grid, + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_H = input.sizes[2]; + index_t inp_W = input.sizes[3]; + + index_t out_H = grid.sizes[1]; + index_t out_W = grid.sizes[2]; + + index_t g2inp_sN = grad2_grad_input.strides[0]; + index_t g2inp_sC = grad2_grad_input.strides[1]; + index_t g2inp_sH = grad2_grad_input.strides[2]; + index_t g2inp_sW = grad2_grad_input.strides[3]; + + index_t g2grid_sN = grad2_grad_grid.strides[0]; + index_t g2grid_sH = grad2_grad_grid.strides[1]; + index_t g2grid_sW = grad2_grad_grid.strides[2]; + index_t g2grid_sCoor = grad2_grad_grid.strides[3]; + + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + index_t gOut_sH = grad_output.strides[2]; + index_t gOut_sW = grad_output.strides[3]; + + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sH = input.strides[2]; + index_t inp_sW = input.strides[3]; + + index_t grid_sN = grid.strides[0]; + index_t grid_sH = grid.strides[1]; + index_t grid_sW = grid.strides[2]; + index_t grid_sCoor = grid.strides[3]; + + index_t gInp_sN = grad_input.strides[0]; + index_t gInp_sC = grad_input.strides[1]; + index_t gInp_sH = grad_input.strides[2]; + index_t gInp_sW = grad_input.strides[3]; + + index_t gGrid_sW = grad_grid.strides[2]; + + index_t ggOut_sN = grad_grad_output.strides[0]; + index_t ggOut_sC = grad_grad_output.strides[1]; + index_t ggOut_sH = grad_grad_output.strides[2]; + index_t ggOut_sW = grad_grad_output.strides[3]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t n = index / (out_H * out_W); + + /* Grid related staff */ + index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult; + scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_nw = static_cast(::floor(ix)); + index_t iy_nw = static_cast(::floor(iy)); + index_t ix_ne = ix_nw + 1; + index_t iy_ne = iy_nw; + index_t ix_sw = ix_nw; + index_t iy_sw = iy_nw + 1; + index_t ix_se = ix_nw + 1; + index_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + /* grad2_grad_input related init */ + scalar_t *g2_inp_ptr_NC = 
grad2_grad_input.data + n * g2inp_sN; + + /* grad2_grad_grid related init */ + grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW; + scalar_t dx = grad2_grad_grid.data[grid_offset]; + scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor]; + + dx = dx * gix_mult; + dy = dy * giy_mult; + + /* grad_output related init */ + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + + /* input related init */ + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + /* grad_grad_output related init */ + scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW; + + /* grad_input related init */ + index_t NC_offset = n * gInp_sN; + + /* grad_grid related init */ + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + scalar_t gix = static_cast(0), giy = static_cast(0); + + scalar_t nw_val, ne_val, sw_val, se_val; + scalar_t g2_nw_val, g2_ne_val, g2_sw_val, g2_se_val; + + scalar_t zero = static_cast(0); + for (index_t c = 0; c < C; + ++c, + g2_inp_ptr_NC += g2inp_sC, + inp_ptr_NC += inp_sC, + NC_offset += gInp_sC, + gOut_ptr_NCHW += gOut_sC, + ggOut_ptr_NCHW += ggOut_sC) { + + nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]: zero; + ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]: zero; + sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]: zero; + se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]: zero; + + g2_nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW]: zero; + g2_ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW]: zero; + g2_sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW]: zero; + g2_se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? 
g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW]: zero; + + // Computing gradient wrt to grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val + // grad2_grad_input * x * y + *ggOut_ptr_NCHW = static_cast(0); + *ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se; + + scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix); + scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw); + scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix); + scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw); + + + // grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val + *ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val; + + // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x + scalar_t gOut = *gOut_ptr_NCHW; + //scalar_t val; + //val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix)); + safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw)); + safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix)); + safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw)); + safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span); + + scalar_t dxy = nw_val - ne_val - sw_val + se_val; + // Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut + gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy) + -g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw)); + gix += gOut * dy * dxy; + + // Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut + giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw) + +g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw)); + giy += gOut * dx * dxy; + } + + gGrid_ptr_NHW[0] = gix * gix_mult; + gGrid_ptr_NHW[1] = giy * giy_mult; + } +} + +template + C10_LAUNCH_BOUNDS_1(256) + __global__ void grid_sampler_3d_grad2_kernel( + const index_t nthreads, + TensorInfo grad2_grad_input, + TensorInfo grad2_grad_grid, + TensorInfo grad_output, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_grad_output, + TensorInfo grad_input, + TensorInfo grad_grid, + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + + index_t out_D = grid.sizes[1]; + index_t out_H = grid.sizes[2]; + index_t out_W = grid.sizes[3]; + + index_t g2inp_sN = grad2_grad_input.strides[0]; + index_t g2inp_sC = grad2_grad_input.strides[1]; + index_t g2inp_sD = grad2_grad_input.strides[2]; + index_t g2inp_sH = grad2_grad_input.strides[3]; + index_t g2inp_sW = grad2_grad_input.strides[4]; + + index_t g2grid_sN = grad2_grad_grid.strides[0]; + index_t g2grid_sD = grad2_grad_grid.strides[1]; + index_t g2grid_sH = grad2_grad_grid.strides[2]; + index_t g2grid_sW = grad2_grad_grid.strides[3]; + index_t g2grid_sCoor = grad2_grad_grid.strides[4]; + + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = 
grad_output.strides[1]; + index_t gOut_sD = grad_output.strides[2]; + index_t gOut_sH = grad_output.strides[3]; + index_t gOut_sW = grad_output.strides[4]; + + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + + index_t grid_sN = grid.strides[0]; + index_t grid_sD = grid.strides[1]; + index_t grid_sH = grid.strides[2]; + index_t grid_sW = grid.strides[3]; + index_t grid_sCoor = grid.strides[4]; + + index_t gInp_sN = grad_input.strides[0]; + index_t gInp_sC = grad_input.strides[1]; + index_t gInp_sD = grad_input.strides[2]; + index_t gInp_sH = grad_input.strides[3]; + index_t gInp_sW = grad_input.strides[4]; + + index_t gGrid_sW = grad_grid.strides[3]; + + index_t ggOut_sN = grad_grad_output.strides[0]; + index_t ggOut_sC = grad_grad_output.strides[1]; + index_t ggOut_sD = grad_grad_output.strides[2]; + index_t ggOut_sH = grad_grad_output.strides[3]; + index_t ggOut_sW = grad_grad_output.strides[4]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t d = (index / (out_H * out_W)) % out_D; + const index_t n = index / (out_D * out_H * out_W); + + /* Grid related staff */ + index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult, giz_mult; + ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_tnw = static_cast(::floor(ix)); + index_t iy_tnw = static_cast(::floor(iy)); + index_t iz_tnw = static_cast(::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + /* grad2_grad_input related init */ + scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN; + + /* grad2_grad_grid related init */ + grid_offset = n * g2grid_sN + d * 
g2grid_sD + h * g2grid_sH + w * g2grid_sW; + scalar_t dx = grad2_grad_grid.data[grid_offset]; + scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor]; + scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor]; + + dx = dx * gix_mult; + dy = dy * giy_mult; + dz = dz * giz_mult; + + /* grad_output related init */ + scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + + /* input related init */ + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + /* grad_grad_output related init */ + scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW; + + /* grad_input related init */ + index_t NC_offset = n * gInp_sN; + + /* grad_grid related init */ + scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + scalar_t gix = static_cast(0), giy = static_cast(0), giz = static_cast(0); + + scalar_t tnw_val, tne_val, tsw_val, tse_val, bnw_val, bne_val, bsw_val, bse_val; + scalar_t g2_tnw_val, g2_tne_val, g2_tsw_val, g2_tse_val, g2_bnw_val, g2_bne_val, g2_bsw_val, g2_bse_val; + + scalar_t zero = static_cast(0); + for (index_t c = 0; c < C; + ++c, + g2_inp_ptr_NC += g2inp_sC, + inp_ptr_NC += inp_sC, + NC_offset += gInp_sC, + gOut_ptr_NCDHW += gOut_sC, + ggOut_ptr_NCDHW += ggOut_sC) { + + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW]; + } else { + tnw_val = zero; + g2_tnw_val = zero; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW]; + } else { + tne_val = zero; + g2_tne_val = zero; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW]; + } else { + tsw_val = zero; + g2_tsw_val = zero; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW]; + } else { + tse_val = zero; + g2_tse_val = zero; + } + + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW]; + } else { + bnw_val = zero; + g2_bnw_val = zero; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW]; + } else { + bne_val = zero; + g2_bne_val = zero; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW]; + } else { + bsw_val = zero; + g2_bsw_val = zero; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + 
iy_bse * g2inp_sH + ix_bse * g2inp_sW]; + } else { + bse_val = zero; + g2_bse_val = zero; + } + + // Computing gradient wrt to grad_output = + // grad2_grad_input * x * y * z + *ggOut_ptr_NCDHW = static_cast(0); + *ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse + +g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse; + + // +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y) + scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy)); + scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy)); + scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne)); + scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw)); + scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy)); + scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy)); + scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne)); + scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw)); + + *ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + tsw_val * tsw_tmp + tse_val * tse_tmp + +bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp; + + // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z + + // grad2_grad_grid_z * grad_output * y * z + scalar_t gOut = *gOut_ptr_NCDHW; + + safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut, + NC_offset, grad_input_memory_span); + + //Computing gradient wrt grid + scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz) + -tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz) + +bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw) + -bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw)); + + scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy) + +tsw_val * (iy - 
iy_bne) - tse_val * (iy - iy_bnw) + -bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy) + -bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw)); + + scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw) + -tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw) + -bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw) + +bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw)); + + + // Computing gradient wrt grid_x = + // grad2_grad_input * z * y * gOut + gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz) + -g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz) + -g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw) + -g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw)); + + //+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut + gix += gOut * (dz * dxz + dy * dxy); + + // Computing gradient wrt grid_y = + // grad2_grad_input * x * z * gOut + giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz) + +g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz) + -g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw) + +g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw)); + //+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut + giy += gOut * (dx * dxy + dz * dyz); + + // Computing gradient wrt grid_z = + // grad2_grad_input * x * y * gOut + giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy) + -g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw) + +g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy) + +g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw)); + //+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut + giz += gOut * (dx * dxz + dy * dyz); + } + + gGrid_ptr_NDHW[0] = gix * gix_mult; + gGrid_ptr_NDHW[1] = giy * giy_mult; + gGrid_ptr_NDHW[2] = giz * giz_mult; + } +}} + + +std::vector grid_sample2d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + const auto batch_size = input.size(0); + const auto C = input.size(1); + const auto H_IN = input.size(2); + const auto W_IN = input.size(3); + + const auto H_OUT = grid.size(1); + const auto W_OUT = grid.size(2); + + torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + int64_t count = batch_size * H_OUT * W_OUT; + + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_grad2_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_2d_grad2_kernel + <<>>( + static_cast(count), + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + 
static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_2d_grad2_kernel + <<>>( + count, + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } + + return {grad_grad_output, grad_input, grad_grid}; +} + +std::vector grid_sample3d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + const auto batch_size = input.size(0); + const auto C = input.size(1); + const auto D_IN = input.size(2); + const auto H_IN = input.size(3); + const auto W_IN = input.size(4); + + const auto D_OUT = grid.size(1); + const auto H_OUT = grid.size(2); + const auto W_OUT = grid.size(3); + + torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + int64_t count = batch_size * D_OUT * H_OUT * W_OUT; + + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_grad2_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_3d_grad2_kernel + <<>>( + static_cast(count), + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_3d_grad2_kernel + <<>>( + count, + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } + + return {grad_grad_output, grad_input, grad_grid}; +} + +}} diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.py b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..8195ca18e17643cd286348ffabf274330330c5d0 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.py @@ -0,0 +1,145 @@ +from torch.utils.cpp_extension import load +import torch +from pkg_resources import parse_version + +from .. 
import custom_ops +import os + +# gridsample_grad2 = load(name='gridsample_grad2', sources=['third_party/ops/gridsample_cuda.cpp', 'third_party/ops/gridsample_cuda.cu'], verbose=True) + +gridsample_grad2 = load(name='gridsample_grad2', sources=[os.path.join(os.path.dirname(__file__), f) for f in ['grid_sample.cpp', 'grid_sample.cu']], verbose=True) + +gridsample_grad2 = None + +def _init(): + global gridsample_grad2 + if gridsample_grad2 is None: + gridsample_grad2 = custom_ops.get_plugin( + module_name='gridsample_grad2', + sources=['gridsample_cuda.cpp', 'gridsample_cuda.cu'], + headers=None, + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +_init() + +def grid_sample_2d(input, grid, padding_mode='zeros', align_corners=True): + assert padding_mode in ['zeros', 'border'] + return _GridSample2dForward.apply(input, grid, padding_mode, align_corners) + + +def grid_sample_3d(input, grid, padding_mode='zeros', align_corners=True): + assert padding_mode in ['zeros', 'border'] + return _GridSample3dForward.apply(input, grid, padding_mode, align_corners) + + +_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') + + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid, padding_mode=0, align_corners=True): + assert input.ndim == 4 + assert grid.ndim == 4 + assert input.shape[0] == grid.shape[0] + assert grid.shape[3] == 2 + + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', + padding_mode=padding_mode, align_corners=align_corners) + ctx.save_for_backward(input, grid) + ctx.padding_mode = ['zeros', 'border'].index(padding_mode) + ctx.align_corners = align_corners + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners) + return grad_input, grad_grid, None, None + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')[0] + if _use_pytorch_1_11_api: + output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) + # breakpoint() + grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask) + else: + grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners) + + ctx.save_for_backward(grad_output, input, grid) + ctx.padding_mode = padding_mode + ctx.align_corners = align_corners + + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + grad_output, input, grid = ctx.saved_tensors + assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda + out = gridsample_grad2.grad2_2d(grad2_grad_input, grad2_grad_grid, grad_output, + input, grid, ctx.padding_mode, ctx.align_corners) + + grad_grad_output = out[0] + grad_input = out[1] + grad_grid = out[2] + + return grad_grad_output, grad_input, grad_grid, None, None + + +class _GridSample3dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid, padding_mode=0, align_corners=True): + assert input.ndim == 5 + assert grid.ndim == 5 + assert input.shape[0] == grid.shape[0] + assert grid.shape[4] == 3 + + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', 
+                                                 padding_mode=padding_mode, align_corners=align_corners)
+        ctx.save_for_backward(input, grid)
+        ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
+        ctx.align_corners = align_corners
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, grid = ctx.saved_tensors
+        grad_input, grad_grid = _GridSample3dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
+        return grad_input, grad_grid, None, None
+
+
+
+class _GridSample3dBackward(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
+        op = torch._C._jit_get_operation('aten::grid_sampler_3d_backward')[0]  # index [0] as in the 2D wrapper above
+        if _use_pytorch_1_11_api:
+            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
+            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
+        else:
+            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
+
+        ctx.save_for_backward(grad_output, input, grid)
+        ctx.padding_mode = padding_mode
+        ctx.align_corners = align_corners
+
+        return grad_input, grad_grid
+
+    @staticmethod
+    def backward(ctx, grad2_grad_input, grad2_grad_grid):
+        grad_output, input, grid = ctx.saved_tensors
+        assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
+        out = gridsample_grad2.grad2_3d(grad2_grad_input, grad2_grad_grid, grad_output,
+                                        input, grid, ctx.padding_mode, ctx.align_corners)
+
+        grad_grad_output = out[0]
+        grad_input = out[1]
+        grad_grid = out[2]
+
+        return grad_grad_output, grad_input, grad_grid, None, None
+
+
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample_gradfix.py b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5d766c16fa3dca90f60abba16d634dfb0510300
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample_gradfix.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = True # Enable the custom op by setting this to true.
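Editor's note: for orientation, a sketch of the situation this 2D-only wrapper targets: a loss term built from the gradient of a sampled image, which requires a second backward pass through grid sampling via the `grid_sample()` helper defined below. The import path and tensor shapes are illustrative assumptions, and the example presumes a PyTorch version compatible with the `aten::grid_sampler_2d_backward` call this file makes.

import torch
import grid_sample_gradfix  # hypothetical import; adjust to this module's location under third_party/ops

N, C, H, W = 2, 3, 16, 16
image  = torch.randn(N, C, H, W, requires_grad=True)
weight = torch.randn(N, C, H, W, requires_grad=True)
grid   = torch.rand(N, H, W, 2) * 2 - 1  # normalized sampling locations in [-1, 1]

sampled = grid_sample_gradfix.grid_sample(image, grid)
loss = (sampled * weight).sum()

# First backward: gradient w.r.t. the input image, keeping the graph alive.
(g,) = torch.autograd.grad(loss, image, create_graph=True)

# Second backward (e.g. a gradient-penalty term): this pass is what needs the
# custom double-backward path rather than plain torch.nn.functional.grid_sample.
g.square().sum().backward()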
+ +#---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None # grad2_grad_input # + grad2_grid = None # grad2_grad_grid # + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cpp b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05ba9baa08b8bf0aac3643219306b871f9dee650 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cpp @@ -0,0 +1,57 @@ +#include + +#include + +// CUDA forward declarations + +namespace at {namespace native { +std::vector grid_sample2d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners); +std::vector grid_sample3d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners); +}} + +std::vector grid_sample2d_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + return at::native::grid_sample2d_cuda_grad2(grad2_grad_input, grad2_grad_grid, + grad_output, input, grid, padding_mode, align_corners); +} + +std::vector grid_sample3d_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const 
torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + return at::native::grid_sample3d_cuda_grad2(grad2_grad_input, grad2_grad_grid, + grad_output, input, grid, padding_mode, align_corners); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grad2_2d", &grid_sample2d_grad2, "grid_sample2d second derivative"); + m.def("grad2_3d", &grid_sample3d_grad2, "grid_sample3d second derivative"); +} + diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cu b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..02b5a33918eec1c6baac348b765bb2e3467ce319 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cu @@ -0,0 +1,668 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace at { namespace native { +namespace { + +using namespace at::cuda::detail; + +using at::native::detail::GridSamplerInterpolation; +using at::native::detail::GridSamplerPadding; + +template + C10_LAUNCH_BOUNDS_1(256) + __global__ void grid_sampler_2d_grad2_kernel( + const index_t nthreads, + TensorInfo grad2_grad_input, + TensorInfo grad2_grad_grid, + TensorInfo grad_output, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_grad_output, + TensorInfo grad_input, + TensorInfo grad_grid, + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_H = input.sizes[2]; + index_t inp_W = input.sizes[3]; + + index_t out_H = grid.sizes[1]; + index_t out_W = grid.sizes[2]; + + index_t g2inp_sN = grad2_grad_input.strides[0]; + index_t g2inp_sC = grad2_grad_input.strides[1]; + index_t g2inp_sH = grad2_grad_input.strides[2]; + index_t g2inp_sW = grad2_grad_input.strides[3]; + + index_t g2grid_sN = grad2_grad_grid.strides[0]; + index_t g2grid_sH = grad2_grad_grid.strides[1]; + index_t g2grid_sW = grad2_grad_grid.strides[2]; + index_t g2grid_sCoor = grad2_grad_grid.strides[3]; + + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + index_t gOut_sH = grad_output.strides[2]; + index_t gOut_sW = grad_output.strides[3]; + + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sH = input.strides[2]; + index_t inp_sW = input.strides[3]; + + index_t grid_sN = grid.strides[0]; + index_t grid_sH = grid.strides[1]; + index_t grid_sW = grid.strides[2]; + index_t grid_sCoor = grid.strides[3]; + + index_t gInp_sN = grad_input.strides[0]; + index_t gInp_sC = grad_input.strides[1]; + index_t gInp_sH = grad_input.strides[2]; + index_t gInp_sW = grad_input.strides[3]; + + index_t gGrid_sW = grad_grid.strides[2]; + + index_t ggOut_sN = grad_grad_output.strides[0]; + index_t ggOut_sC = grad_grad_output.strides[1]; + index_t ggOut_sH = grad_grad_output.strides[2]; + index_t ggOut_sW = grad_grad_output.strides[3]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t n = index / (out_H * out_W); + + /* Grid related staff */ + index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult; + scalar_t ix = 
grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_nw = static_cast(::floor(ix)); + index_t iy_nw = static_cast(::floor(iy)); + index_t ix_ne = ix_nw + 1; + index_t iy_ne = iy_nw; + index_t ix_sw = ix_nw; + index_t iy_sw = iy_nw + 1; + index_t ix_se = ix_nw + 1; + index_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + /* grad2_grad_input related init */ + scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN; + + /* grad2_grad_grid related init */ + grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW; + scalar_t dx = grad2_grad_grid.data[grid_offset]; + scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor]; + + dx = dx * gix_mult; + dy = dy * giy_mult; + + /* grad_output related init */ + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + + /* input related init */ + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + /* grad_grad_output related init */ + scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW; + + /* grad_input related init */ + index_t NC_offset = n * gInp_sN; + + /* grad_grid related init */ + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + scalar_t gix = static_cast(0), giy = static_cast(0); + + scalar_t nw_val, ne_val, sw_val, se_val; + scalar_t g2_nw_val, g2_ne_val, g2_sw_val, g2_se_val; + + scalar_t zero = static_cast(0); + for (index_t c = 0; c < C; + ++c, + g2_inp_ptr_NC += g2inp_sC, + inp_ptr_NC += inp_sC, + NC_offset += gInp_sC, + gOut_ptr_NCHW += gOut_sC, + ggOut_ptr_NCHW += ggOut_sC) { + + nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]: zero; + ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]: zero; + sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]: zero; + se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]: zero; + + g2_nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW]: zero; + g2_ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW]: zero; + g2_sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW]: zero; + g2_se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? 
g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW]: zero; + + // Computing gradient wrt to grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val + // grad2_grad_input * x * y + *ggOut_ptr_NCHW = static_cast(0); + *ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se; + + scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix); + scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw); + scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix); + scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw); + + + // grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val + *ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val; + + // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x + scalar_t gOut = *gOut_ptr_NCHW; + //scalar_t val; + //val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix)); + safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw)); + safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix)); + safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw)); + safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span); + + scalar_t dxy = nw_val - ne_val - sw_val + se_val; + // Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut + gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy) + -g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw)); + gix += gOut * dy * dxy; + + // Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut + giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw) + +g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw)); + giy += gOut * dx * dxy; + } + + gGrid_ptr_NHW[0] = gix * gix_mult; + gGrid_ptr_NHW[1] = giy * giy_mult; + } +} + +template + C10_LAUNCH_BOUNDS_1(256) + __global__ void grid_sampler_3d_grad2_kernel( + const index_t nthreads, + TensorInfo grad2_grad_input, + TensorInfo grad2_grad_grid, + TensorInfo grad_output, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_grad_output, + TensorInfo grad_input, + TensorInfo grad_grid, + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + + index_t out_D = grid.sizes[1]; + index_t out_H = grid.sizes[2]; + index_t out_W = grid.sizes[3]; + + index_t g2inp_sN = grad2_grad_input.strides[0]; + index_t g2inp_sC = grad2_grad_input.strides[1]; + index_t g2inp_sD = grad2_grad_input.strides[2]; + index_t g2inp_sH = grad2_grad_input.strides[3]; + index_t g2inp_sW = grad2_grad_input.strides[4]; + + index_t g2grid_sN = grad2_grad_grid.strides[0]; + index_t g2grid_sD = grad2_grad_grid.strides[1]; + index_t g2grid_sH = grad2_grad_grid.strides[2]; + index_t g2grid_sW = grad2_grad_grid.strides[3]; + index_t g2grid_sCoor = grad2_grad_grid.strides[4]; + + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = 
grad_output.strides[1]; + index_t gOut_sD = grad_output.strides[2]; + index_t gOut_sH = grad_output.strides[3]; + index_t gOut_sW = grad_output.strides[4]; + + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + + index_t grid_sN = grid.strides[0]; + index_t grid_sD = grid.strides[1]; + index_t grid_sH = grid.strides[2]; + index_t grid_sW = grid.strides[3]; + index_t grid_sCoor = grid.strides[4]; + + index_t gInp_sN = grad_input.strides[0]; + index_t gInp_sC = grad_input.strides[1]; + index_t gInp_sD = grad_input.strides[2]; + index_t gInp_sH = grad_input.strides[3]; + index_t gInp_sW = grad_input.strides[4]; + + index_t gGrid_sW = grad_grid.strides[3]; + + index_t ggOut_sN = grad_grad_output.strides[0]; + index_t ggOut_sC = grad_grad_output.strides[1]; + index_t ggOut_sD = grad_grad_output.strides[2]; + index_t ggOut_sH = grad_grad_output.strides[3]; + index_t ggOut_sW = grad_grad_output.strides[4]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t d = (index / (out_H * out_W)) % out_D; + const index_t n = index / (out_D * out_H * out_W); + + /* Grid related staff */ + index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult, giz_mult; + ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_tnw = static_cast(::floor(ix)); + index_t iy_tnw = static_cast(::floor(iy)); + index_t iz_tnw = static_cast(::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + /* grad2_grad_input related init */ + scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN; + + /* grad2_grad_grid related init */ + grid_offset = n * g2grid_sN + d * 
g2grid_sD + h * g2grid_sH + w * g2grid_sW; + scalar_t dx = grad2_grad_grid.data[grid_offset]; + scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor]; + scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor]; + + dx = dx * gix_mult; + dy = dy * giy_mult; + dz = dz * giz_mult; + + /* grad_output related init */ + scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + + /* input related init */ + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + /* grad_grad_output related init */ + scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW; + + /* grad_input related init */ + index_t NC_offset = n * gInp_sN; + + /* grad_grid related init */ + scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + scalar_t gix = static_cast(0), giy = static_cast(0), giz = static_cast(0); + + scalar_t tnw_val, tne_val, tsw_val, tse_val, bnw_val, bne_val, bsw_val, bse_val; + scalar_t g2_tnw_val, g2_tne_val, g2_tsw_val, g2_tse_val, g2_bnw_val, g2_bne_val, g2_bsw_val, g2_bse_val; + + scalar_t zero = static_cast(0); + for (index_t c = 0; c < C; + ++c, + g2_inp_ptr_NC += g2inp_sC, + inp_ptr_NC += inp_sC, + NC_offset += gInp_sC, + gOut_ptr_NCDHW += gOut_sC, + ggOut_ptr_NCDHW += ggOut_sC) { + + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW]; + } else { + tnw_val = zero; + g2_tnw_val = zero; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW]; + } else { + tne_val = zero; + g2_tne_val = zero; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW]; + } else { + tsw_val = zero; + g2_tsw_val = zero; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW]; + } else { + tse_val = zero; + g2_tse_val = zero; + } + + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW]; + } else { + bnw_val = zero; + g2_bnw_val = zero; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW]; + } else { + bne_val = zero; + g2_bne_val = zero; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW]; + } else { + bsw_val = zero; + g2_bsw_val = zero; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + 
iy_bse * g2inp_sH + ix_bse * g2inp_sW]; + } else { + bse_val = zero; + g2_bse_val = zero; + } + + // Computing gradient wrt to grad_output = + // grad2_grad_input * x * y * z + *ggOut_ptr_NCDHW = static_cast(0); + *ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse + +g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse; + + // +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y) + scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy)); + scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy)); + scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne)); + scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw)); + scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy)); + scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy)); + scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne)); + scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw)); + + *ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + tsw_val * tsw_tmp + tse_val * tse_tmp + +bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp; + + // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z + + // grad2_grad_grid_z * grad_output * y * z + scalar_t gOut = *gOut_ptr_NCDHW; + + safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut, + NC_offset, grad_input_memory_span); + safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut, + NC_offset, grad_input_memory_span); + + //Computing gradient wrt grid + scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz) + -tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz) + +bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw) + -bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw)); + + scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy) + +tsw_val * (iy - 
iy_bne) - tse_val * (iy - iy_bnw) + -bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy) + -bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw)); + + scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw) + -tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw) + -bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw) + +bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw)); + + + // Computing gradient wrt grid_x = + // grad2_grad_input * z * y * gOut + gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz) + -g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz) + -g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw) + -g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw)); + + //+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut + gix += gOut * (dz * dxz + dy * dxy); + + // Computing gradient wrt grid_y = + // grad2_grad_input * x * z * gOut + giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz) + +g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz) + -g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw) + +g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw)); + //+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut + giy += gOut * (dx * dxy + dz * dyz); + + // Computing gradient wrt grid_z = + // grad2_grad_input * x * y * gOut + giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy) + -g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw) + +g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy) + +g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw)); + //+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut + giz += gOut * (dx * dxz + dy * dyz); + } + + gGrid_ptr_NDHW[0] = gix * gix_mult; + gGrid_ptr_NDHW[1] = giy * giy_mult; + gGrid_ptr_NDHW[2] = giz * giz_mult; + } +}} + + +std::vector grid_sample2d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + const auto batch_size = input.size(0); + const auto C = input.size(1); + const auto H_IN = input.size(2); + const auto W_IN = input.size(3); + + const auto H_OUT = grid.size(1); + const auto W_OUT = grid.size(2); + + torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + int64_t count = batch_size * H_OUT * W_OUT; + + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_grad2_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_2d_grad2_kernel + <<>>( + static_cast(count), + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + 
static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_2d_grad2_kernel + <<>>( + count, + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } + + return {grad_grad_output, grad_input, grad_grid}; +} + +std::vector grid_sample3d_cuda_grad2( + const torch::Tensor &grad2_grad_input, + const torch::Tensor &grad2_grad_grid, + const torch::Tensor &grad_output, + const torch::Tensor &input, + const torch::Tensor &grid, + bool padding_mode, + bool align_corners) { + + const auto batch_size = input.size(0); + const auto C = input.size(1); + const auto D_IN = input.size(2); + const auto H_IN = input.size(3); + const auto W_IN = input.size(4); + + const auto D_OUT = grid.size(1); + const auto H_OUT = grid.size(2); + const auto W_OUT = grid.size(3); + + torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + int64_t count = batch_size * D_OUT * H_OUT * W_OUT; + + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_grad2_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_3d_grad2_kernel + <<>>( + static_cast(count), + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_3d_grad2_kernel + <<>>( + count, + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } + + return {grad_grad_output, grad_input, grad_grid}; +} + +}} diff --git a/svrm/ldm/modules/rendering_neus/third_party/pytorch_ssim/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/pytorch_ssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14d38b1722a7242ecd5353063eecb38ccc3aee02 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/third_party/pytorch_ssim/__init__.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def 
_ssim(img1, img2, window, window_size, channel, use_padding, size_average=True): + + if use_padding: + padding_size = window_size // 2 + else: + padding_size = 0 + + mu1 = F.conv2d(img1, window, padding=padding_size, groups=channel) + mu2 = F.conv2d(img2, window, padding=padding_size, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=padding_size, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=padding_size, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=padding_size, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, use_padding=True, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.use_padding = use_padding + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.use_padding, self.size_average) + + +def ssim(img1, img2, use_padding=True, window_size=11, size_average=True): + """SSIM only defined at intensity channel. For RGB or YUV or other image format, this function computes SSIm at each + channel and averge them. + :param img1: (B, C, H, W) float32 in [0, 1] + :param img2: (B, C, H, W) float32 in [0, 1] + :param use_padding: we use conv2d when we compute mean and var for each patch, this use_padding is for that conv2d. + :param window_size: patch size + :param size_average: + :return: a tensor that contains only one scalar. + """ + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, use_padding, size_average) \ No newline at end of file diff --git a/svrm/ldm/modules/rendering_neus/utils/__init__.py b/svrm/ldm/modules/rendering_neus/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3910899b3e0bfba650b294c1dd08a559a234933f --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/utils/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
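For reference, a minimal usage sketch of the `pytorch_ssim` helper added above. It follows the docstring of `ssim` (inputs of shape (B, C, H, W), float32 in [0, 1]); the import path assumes the repository root is on `PYTHONPATH`, and the batch size, channel count, and resolution are illustrative assumptions only.

```python
import torch
from svrm.ldm.modules.rendering_neus.third_party.pytorch_ssim import SSIM, ssim

# Illustrative image batches in [0, 1]; shapes are assumptions for this example.
img1 = torch.rand(2, 3, 64, 64)
img2 = (img1 * 0.9 + 0.05).clamp(0.0, 1.0)  # a slightly perturbed copy of img1

# Functional form: builds a fresh 11x11 Gaussian window and returns a single scalar tensor.
score = ssim(img1, img2, use_padding=True, window_size=11, size_average=True)

# Module form: caches the window and rebuilds it only when the channel count or dtype changes.
ssim_module = SSIM(window_size=11, use_padding=True, size_average=True)
loss = 1.0 - ssim_module(img1, img2)  # SSIM measures similarity, so 1 - SSIM can serve as a loss

print(score.item(), loss.item())
```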
diff --git a/svrm/ldm/modules/rendering_neus/utils/math_utils.py b/svrm/ldm/modules/rendering_neus/utils/math_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca717cfc9d6fb0926d7d8d165661e60ffc27731 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/utils/math_utils.py @@ -0,0 +1,118 @@ +# MIT License + +# Copyright (c) 2022 Petr Kellnhofer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. 
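+    # Slab method: the ray intersects the box only if the per-axis [t_min, t_max] intervals overlap on all three axes.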
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out diff --git a/svrm/ldm/modules/rendering_neus/utils/ray_marcher.py b/svrm/ldm/modules/rendering_neus/utils/ray_marcher.py new file mode 100644 index 0000000000000000000000000000000000000000..753300ff6ce30fe3df2fd198f8e6ba4394d7b309 --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/utils/ray_marcher.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Zexin He +# The modifications are subject to the same license as the original. + + +""" +The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. +Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class LearnedVariance(nn.Module): + def __init__(self, init_val): + super(LearnedVariance, self).__init__() + self.register_parameter("_inv_std", nn.Parameter(torch.tensor(init_val))) + + @property + def inv_std(self): + val = torch.exp(self._inv_std * 10.0) + return val + + def forward(self, x): + return torch.ones_like(x) * self.inv_std.clamp(1.0e-6, 1.0e6) + + +class MipRayMarcher2(nn.Module): + def __init__(self, activation_factory): + super().__init__() + self.activation_factory = activation_factory + self.variance = LearnedVariance(0.3) + self.cos_anneal_ratio = 1.0 + def get_alpha(self, sdf, normal, dirs, dists): + # sdf: [N 1] normal: [N 3] dirs: [N 3] dists: [N 1] + # import ipdb; ipdb.set_trace() + + inv_std = self.variance(sdf) + + true_cos = (dirs * normal).sum(-1, keepdim=True) + # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes + # the cos value "not dead" at the beginning training iterations, for better convergence. + iter_cos = -( + F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio) + + F.relu(-true_cos) * self.cos_anneal_ratio + ) # always non-positive + + # Estimate signed distances at section points + estimated_next_sdf = sdf + iter_cos * dists * 0.5 + estimated_prev_sdf = sdf - iter_cos * dists * 0.5 + + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_std) + + p = prev_cdf - next_cdf + c = prev_cdf + + alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) + return alpha + + def run_forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None): + # depths: [B N_ray*N_sample 1] + # sdfs: [B, N_ray, N_sample 1] + # import ipdb; ipdb.set_trace() + + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + sdfs_mid = (sdfs[:, :, :-1] + sdfs[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + normals_mid = (normals[:, :, :-1] + normals[:, :, 1:]) / 2 + + # zhaohx add for normal : + real_normals_mid = (real_normals[:, :, :-1] + real_normals[:, :, 1:]) / 2 + + # # using factory mode for better usability + # densities_mid = self.activation_factory(rendering_options)(densities_mid) + + # density_delta = densities_mid * deltas + + # alpha = 1 - torch.exp(-density_delta) + + # import ipdb; ipdb.set_trace() + dirs = ray_directions.unsqueeze(2).expand(-1, -1, sdfs_mid.shape[-2], -1) + B, N_ray, N_sample, _ = sdfs_mid.shape + alpha = self.get_alpha(sdfs_mid.reshape(-1, 1), normals_mid.reshape(-1, 3), dirs.reshape(-1, 3), deltas.reshape(-1, 1)) + alpha = alpha.reshape(B, N_ray, N_sample, -1) + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + # import pdb; pdb.set_trace() + + # zhaohx add for normal : + composite_normal = torch.sum(weights * real_normals_mid, -2) / weight_total + composite_normal = torch.nan_to_num(composite_normal, float('inf')) + composite_normal = torch.clamp(composite_normal, 
torch.min(real_normals), torch.max(real_normals)) + + if rendering_options.get('white_back', False): + # composite_rgb = composite_rgb + 1 - weight_total + # weight_total[weight_total < 0.5] = 0 + # composite_rgb = composite_rgb * weight_total + 1 - weight_total + # now is this + if bgcolor is None: + composite_rgb = composite_rgb + 1 - weight_total + # composite_rgb = composite_rgb * weight_total + 1 - weight_total + else: + # import pdb; pdb.set_trace() + bgcolor = bgcolor.permute(0, 2, 3, 1).contiguous().view(composite_rgb.shape[0], -1, composite_rgb.shape[-1]) + composite_rgb = composite_rgb + (1 - weight_total) * bgcolor + # composite_rgb = composite_rgb * weight_total + (1 - weight_total) * bgcolor + # composite_rgb = composite_rgb + # print('new white_back') + + # rendered value scale is 0-1, comment out original mipnerf scaling + # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights, composite_normal + + + def forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None): + composite_rgb, composite_depth, weights, composite_normal = self.run_forward(colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor, real_normals) + + return composite_rgb, composite_depth, weights, composite_normal diff --git a/svrm/ldm/modules/rendering_neus/utils/ray_sampler.py b/svrm/ldm/modules/rendering_neus/utils/ray_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7501478dbac004c3d0a9c0366f989af0b52ef5db --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/utils/ray_sampler.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Zexin He +# The modifications are subject to the same license as the original. + + +""" +The ray sampler is a module that takes in camera matrices and resolution and batches of rays. +Expects cam2world matrices that use the OpenCV camera coordinate system conventions. +""" + +import torch + +class RaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, intrinsics, render_size): + """ + Create batches of rays and return origins and directions. 
+ + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + + N, M = cam2world_matrix.shape[0], render_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.ones((N, M), device=cam2world_matrix.device) + + x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) + + ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) + + return ray_origins, ray_dirs diff --git a/svrm/ldm/modules/rendering_neus/utils/renderer.py b/svrm/ldm/modules/rendering_neus/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..e315c72e47a6a4437bb76e40bcc29d38cd7e2c2c --- /dev/null +++ b/svrm/ldm/modules/rendering_neus/utils/renderer.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Zexin He +# The modifications are subject to the same license as the original. + + +""" +The renderer is a module that takes in rays, decides where to sample along each +ray, and computes pixel colors using the volume rendering equation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ray_marcher import MipRayMarcher2 +from . import math_utils +# from ldm.modules.rendering_neus.third_party.ops import grid_sample + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. 
+ + Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + + coordinates = (2/box_warp) * coordinates # add specific box bounds + # print(coordinates.max(), coordinates.min()) + # import pdb; pdb.set_trace() + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + # output_features = grid_sample.grid_sample_2d(plane_features, projected_coordinates.float().to(plane_features.device)).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + + output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', padding_mode='zeros', align_corners=False) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class ImportanceRenderer(torch.nn.Module): + """ + Modified original version to filter out-of-box samples as TensoRF does. + + Reference: + TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 + """ + def __init__(self): + super().__init__() + self.activation_factory = self._build_activation_factory() + self.ray_marcher = MipRayMarcher2(self.activation_factory) + self.plane_axes = generate_planes() + + def _build_activation_factory(self): + def activation_factory(options: dict): + if options['clamp_mode'] == 'softplus': + return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better + else: + assert False, "Renderer only supports `clamp_mode`=`softplus`!" + return activation_factory + + def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, + planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): + """ + Additional filtering is applied to filter out-of-box samples. + Modifications made by Zexin He. 
+ """ + + # context related variables + batch_size, num_rays, samples_per_ray, _ = depths.shape + device = depths.device + + # define sample points with depths + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + # print(f'min bbox: {sample_coordinates.min()}, max bbox: {sample_coordinates.max()}') + # import pdb; pdb.set_trace() + # filter out-of-box samples + mask_inbox = \ + (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ + (sample_coordinates <= rendering_options['sampler_bbox_max']) + mask_inbox = mask_inbox.all(-1) + + # forward model according to all samples + _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + + # set out-of-box samples to zeros(rgb) & -inf(sigma) + SAFE_GUARD = 3 + DATA_TYPE = _out['sdf'].dtype + colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) + normals_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) + sdfs_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD + + # print(DATA_TYPE) + # import pdb; pdb.set_trace() + # colors_pass[mask_inbox], sdfs_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sdf'][mask_inbox] + colors_pass[mask_inbox], sdfs_pass = _out['rgb'][mask_inbox], _out['sdf'] + normals_pass = _out['normal'] + + # reshape back + colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) + sdfs_pass = sdfs_pass.reshape(batch_size, num_rays, samples_per_ray, sdfs_pass.shape[-1]) + normals_pass = normals_pass.reshape(batch_size, num_rays, samples_per_ray, normals_pass.shape[-1]) + + return colors_pass, sdfs_pass, normals_pass, _out['sdf_grad'] + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, bgcolor=None): + # self.plane_axes = self.plane_axes.to(ray_origins.device) + + if rendering_options['ray_start'] == 'auto' == rendering_options['ray_end']: + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) # [1, N_ray, 1] + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) # [1, N_ray, N_sample, 1]】 + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + + # Coarse Pass + colors_coarse, sdfs_coarse, normals_coarse, sdf_grad = self._forward_pass( + depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + # TODO + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, sdf_grad.reshape(*normals_coarse.shape), ray_directions, rendering_options, bgcolor) + + depths_fine = 
self.sample_importance(depths_coarse, weights, N_importance) + + colors_fine, densities_fine = self._forward_pass( + depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + depths_fine, colors_fine, densities_fine) + #### + # dists = depths_coarse[:, :, 1:, :] - depths_coarse[:, :, :-1, :] + # inter = (ray_end - ray_start) / ( rendering_options['depth_resolution'] + rendering_options['depth_resolution_importance'] - 1) # [1, N_ray, 1] + # dists = torch.cat([dists, inter.unsqueeze(2), 2]) + #### + + # Aggregate + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, bgcolor) + else: + # # import pdb; pdb.set_trace() + # dists = depths_coarse[:, :, 1:, :] - depths_coarse[:, :, :-1, :] + # inter = (ray_end - ray_start) / ( rendering_options['depth_resolution'] - 1) # [1, N_ray, 1] + # dists = torch.cat([dists, inter.unsqueeze(2)], 2) + # # import ipdb; ipdb.set_trace() + + # rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, normals_coarse, dists, ray_directions, rendering_options, bgcolor) + rgb_final, depth_final, weights, normal_final = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, sdf_grad.reshape(*normals_coarse.shape), ray_directions, rendering_options, bgcolor, normals_coarse) + # import ipdb; ipdb.set_trace() + + return rgb_final, depth_final, weights.sum(2), sdf_grad, normal_final + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + plane_axes = self.plane_axes.to(planes.device) + + out = decoder(sample_directions, sample_coordinates, plane_axes, planes, options) + # if options.get('density_noise', 0) > 0: + # out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) + out['sigma'] = self.activation_factory(options)(out['sigma']) + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + return all_depths, all_colors, all_densities + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. 
+ """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) 
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom b c h w') + h = self.upsampler(h) + h = rearrange(h, '(b d) c h w-> b d c h w', d=3) + h = h.type(x.dtype) + return h + else: + h = self.upsampler(h) #[b, h, w, triplane_dim*4] + b, height, width, _ = h.shape + h = h.view(b, height, width, self.triplane_dim, self.upsample_ratio, self.upsample_ratio) #[b, h, w, triplane_dim, 2, 2] + h = h.permute(0,3,1,4,2,5).contiguous() #[b, triplane_dim, h, 2, w, 2] + h = h.view(b, self.triplane_dim, height*self.upsample_ratio, width*self.upsample_ratio) + h = rearrange(h, '(b d) c h w-> b d c h w', d=3) + h = h.type(x.dtype) + return h diff --git a/svrm/ldm/modules/x_transformer.py b/svrm/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7a4d4f2bd89ce015bc7896e73b55772e508a66 --- /dev/null +++ b/svrm/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + 
project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = 
einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - 
len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits 
= nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/svrm/ldm/util.py b/svrm/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1b32ee92cc886d3d949abbfb9fea9c9633cb4a6c --- /dev/null +++ b/svrm/ldm/util.py @@ -0,0 +1,252 @@ +import os +import importlib +from inspect import isfunction +import cv2 +import time +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import matplotlib.pyplot as plt +import torch +from torch import optim +import torchvision + + +def pil_rectangle_crop(im): + width, height = im.size # Get dimensions + + if width <= height: + left = 0 + right = width + top = (height - width)/2 + bottom = (height + width)/2 + else: + + top = 0 + bottom = height + left = (width - height) / 2 + bottom = (width + height) / 2 + + # Crop the center of the image + im = im.crop((left, top, right, bottom)) + return im + + +def add_margin(pil_img, color, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result + + +def load_and_preprocess(interface, input_im): + ''' + :param input_im (PIL Image). + :return image (H, W, 3) array in [0, 1]. + ''' + # See https://github.com/Ir1d/image-background-remove-tool + image = input_im.convert('RGB') + + image_without_background = interface([image])[0] + image_without_background = np.array(image_without_background) + est_seg = image_without_background > 127 + image = np.array(image) + foreground = est_seg[:, : , -1].astype(np.bool_) + image[~foreground] = [255., 255., 255.] 
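+    # Crop to the tight bounding box of the foreground mask, shrink the crop so its long edge is at most 200 px,
+    # then paste it centered on a 256x256 white canvas.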
+ x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) + image = image[y:y+h, x:x+w, :] + image = Image.fromarray(np.array(image)) + + # resize image such that long edge is 512 + image.thumbnail([200, 200], Image.Resampling.LANCZOS) + image = add_margin(image, (255, 255, 255), size=256) + image = np.array(image) + return image + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + 
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss diff --git a/svrm/ldm/utils/ops.py b/svrm/ldm/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..628f2d6208aa457b40b6292c73dbceda0fc7d4a4 --- /dev/null +++ b/svrm/ldm/utils/ops.py @@ -0,0 +1,538 @@ +import os +import math +import numpy as np +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd +from igl import fast_winding_number_for_meshes, point_mesh_squared_distance, read_obj + +from .typing import * + + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def 
reflect(x, n): + return 2 * dot(x, n) * n - x + + +ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] + + +def scale_tensor( + dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale +): + if inp_scale is None: + inp_scale = (0, 1) + if tgt_scale is None: + tgt_scale = (0, 1) + if isinstance(tgt_scale, Tensor): + assert dat.shape[-1] == tgt_scale.shape[-1] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +class _TruncExp(Function): # pylint: disable=abstract-method + # Implementation from torch-ngp: + # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x): # pylint: disable=arguments-differ + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): # pylint: disable=arguments-differ + x = ctx.saved_tensors[0] + return g * torch.exp(torch.clamp(x, max=15)) + + +class SpecifyGradient(Function): + # Implementation from stable-dreamfusion + # https://github.com/ashawkey/stable-dreamfusion + @staticmethod + @custom_fwd + def forward(ctx, input_tensor, gt_grad): + ctx.save_for_backward(gt_grad) + # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward. + return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, grad_scale): + (gt_grad,) = ctx.saved_tensors + gt_grad = gt_grad * grad_scale + return gt_grad, None + + +trunc_exp = _TruncExp.apply + + +def get_activation(name) -> Callable: + if name is None: + return lambda x: x + name = name.lower() + if name == "none": + return lambda x: x + elif name == "lin2srgb": + return lambda x: torch.where( + x > 0.0031308, + torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, + 12.92 * x, + ).clamp(0.0, 1.0) + elif name == "exp": + return lambda x: torch.exp(x) + elif name == "shifted_exp": + return lambda x: torch.exp(x - 1.0) + elif name == "trunc_exp": + return trunc_exp + elif name == "shifted_trunc_exp": + return lambda x: trunc_exp(x - 1.0) + elif name == "sigmoid": + return lambda x: torch.sigmoid(x) + elif name == "tanh": + return lambda x: torch.tanh(x) + elif name == "shifted_softplus": + return lambda x: F.softplus(x - 1.0) + elif name == "scale_-11_01": + return lambda x: x * 0.5 + 0.5 + else: + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") + + +def chunk_batch(func: Callable, chunk_size: int, triplane=None, *args, **kwargs) -> Any: + if chunk_size <= 0: + return func(*args, **kwargs) + B = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + assert ( + B is not None + ), "No tensor found in args or kwargs, cannot determine batch size." 
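+    # run func over successive chunk_size slices of the batch and merge the per-key outputs afterwards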
+ out = defaultdict(list) + out_type = None + # max(1, B) to support B == 0 + for i in range(0, max(1, B), chunk_size): + if triplane is not None: + out_chunk = func(triplane=triplane, + *[ + arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for arg in args + ], + **{ + k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for k, arg in kwargs.items() + }, + ) + else: + out_chunk = func( + *[ + arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for arg in args + ], + **{ + k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for k, arg in kwargs.items() + }, + ) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + print( + f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}." + ) + exit(1) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + out[k].append(v) + + if out_type is None: + return None + + out_merged: Dict[Any, Optional[torch.Tensor]] = {} + for k, v in out.items(): + if all([vv is None for vv in v]): + # allow None in return value + out_merged[k] = None + elif all([isinstance(vv, torch.Tensor) for vv in v]): + out_merged[k] = torch.cat(v, dim=0) + else: + raise TypeError( + f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" + ) + + if out_type is torch.Tensor: + return out_merged[0] + elif out_type in [tuple, list]: + return out_type([out_merged[i] for i in range(chunk_length)]) + elif out_type is dict: + return out_merged + + +def get_ray_directions( + H: int, + W: int, + focal: Union[float, Tuple[float, float]], + principal: Optional[Tuple[float, float]] = None, + use_pixel_centers: bool = True, +) -> Float[Tensor, "H W 3"]: + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + + Inputs: + H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + pixel_center = 0.5 if use_pixel_centers else 0 + + if isinstance(focal, float): + fx, fy = focal, focal + cx, cy = W / 2, H / 2 + else: + fx, fy = focal + assert principal is not None + cx, cy = principal + + i, j = torch.meshgrid( + torch.arange(W, dtype=torch.float32) + pixel_center, + torch.arange(H, dtype=torch.float32) + pixel_center, + indexing="xy", + ) + + directions: Float[Tensor, "H W 3"] = torch.stack( + [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1 + ) + + return directions + + +def get_rays( + directions: Float[Tensor, "... 3"], + c2w: Float[Tensor, "... 4 4"], + keepdim=False, + noise_scale=0.0, +) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 
3"]]: + # Rotate ray directions from camera coordinate to the world coordinate + assert directions.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + if c2w.ndim == 2: # (4, 4) + c2w = c2w[None, :, :] + assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) + rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) + rays_o = c2w[:, :3, 3].expand(rays_d.shape) + elif directions.ndim == 3: # (H, W, 3) + assert c2w.ndim in [2, 3] + if c2w.ndim == 2: # (4, 4) + rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( + -1 + ) # (H, W, 3) + rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + elif directions.ndim == 4: # (B, H, W, 3) + assert c2w.ndim == 3 # (B, 4, 4) + rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + + # add camera noise to avoid grid-like artifect + # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373 + if noise_scale > 0: + rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale + rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale + + rays_d = F.normalize(rays_d, dim=-1) + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d + + +def get_projection_matrix( + fovy: Float[Tensor, "B"], aspect_wh: float, near: float, far: float +) -> Float[Tensor, "B 4 4"]: + batch_size = fovy.shape[0] + proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32) + proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh) + proj_mtx[:, 1, 1] = -1.0 / torch.tan( + fovy / 2.0 + ) # add a negative sign here as the y axis is flipped in nvdiffrast output + proj_mtx[:, 2, 2] = -(far + near) / (far - near) + proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near) + proj_mtx[:, 3, 2] = -1.0 + return proj_mtx + + +def get_mvp_matrix( + c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"] +) -> Float[Tensor, "B 4 4"]: + # calculate w2c from c2w: R' = Rt, t' = -Rt * t + # mathematically equivalent to (c2w)^-1 + w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w) + w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1) + w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:] + w2c[:, 3, 3] = 1.0 + # calculate mvp matrix by proj_mtx @ w2c (mv_mtx) + mvp_mtx = proj_mtx @ w2c + return mvp_mtx + + +def get_full_projection_matrix( + c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"] +) -> Float[Tensor, "B 4 4"]: + return (c2w.unsqueeze(0).bmm(proj_mtx.unsqueeze(0))).squeeze(0) + + +# gaussian splatting functions +def convert_pose(C2W): + flip_yz = torch.eye(4, device=C2W.device) + flip_yz[1, 1] = -1 + flip_yz[2, 2] = -1 + C2W = torch.matmul(C2W, flip_yz) + return C2W + + +def get_projection_matrix_gaussian(znear, zfar, fovX, fovY, device="cuda"): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4, device=device) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - 
znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + +def get_fov_gaussian(P): + tanHalfFovX = 1 / P[0, 0] + tanHalfFovY = 1 / P[1, 1] + fovY = math.atan(tanHalfFovY) * 2 + fovX = math.atan(tanHalfFovX) * 2 + return fovX, fovY + + +def get_cam_info_gaussian(c2w, fovx, fovy, znear, zfar): + c2w = convert_pose(c2w) + world_view_transform = torch.inverse(c2w) + + world_view_transform = world_view_transform.transpose(0, 1).cuda().float() + projection_matrix = ( + get_projection_matrix_gaussian(znear=znear, zfar=zfar, fovX=fovx, fovY=fovy) + .transpose(0, 1) + .cuda() + ) + full_proj_transform = ( + world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) + ).squeeze(0) + camera_center = world_view_transform.inverse()[3, :3] + + return world_view_transform, full_proj_transform, camera_center + + +def binary_cross_entropy(input, target): + """ + F.binary_cross_entropy is not numerically stable in mixed-precision training. + """ + return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean() + + +def tet_sdf_diff( + vert_sdf: Float[Tensor, "Nv 1"], tet_edges: Integer[Tensor, "Ne 2"] +) -> Float[Tensor, ""]: + sdf_f1x6x2 = vert_sdf[:, 0][tet_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float() + ) + F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float() + ) + return sdf_diff + + +# Implementation from Latent-NeRF +# https://github.com/eladrich/latent-nerf/blob/f49ecefcd48972e69a28e3116fe95edf0fac4dc8/src/latent_nerf/models/mesh_utils.py +class MeshOBJ: + dx = torch.zeros(3).float() + dx[0] = 1 + dy, dz = dx[[1, 0, 2]], dx[[2, 1, 0]] + dx, dy, dz = dx[None, :], dy[None, :], dz[None, :] + + def __init__(self, v: np.ndarray, f: np.ndarray): + self.v = v + self.f = f + self.dx, self.dy, self.dz = MeshOBJ.dx, MeshOBJ.dy, MeshOBJ.dz + self.v_tensor = torch.from_numpy(self.v) + + vf = self.v[self.f, :] + self.f_center = vf.mean(axis=1) + self.f_center_tensor = torch.from_numpy(self.f_center).float() + + e1 = vf[:, 1, :] - vf[:, 0, :] + e2 = vf[:, 2, :] - vf[:, 0, :] + self.face_normals = np.cross(e1, e2) + self.face_normals = ( + self.face_normals / np.linalg.norm(self.face_normals, axis=-1)[:, None] + ) + self.face_normals_tensor = torch.from_numpy(self.face_normals) + + def normalize_mesh(self, target_scale=0.5): + verts = self.v + + # Compute center of bounding box + # center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]])) + center = verts.mean(axis=0) + verts = verts - center + scale = np.max(np.linalg.norm(verts, axis=1)) + verts = (verts / scale) * target_scale + + return MeshOBJ(verts, self.f) + + def winding_number(self, query: torch.Tensor): + device = query.device + shp = query.shape + query_np = query.detach().cpu().reshape(-1, 3).numpy() + target_alphas = fast_winding_number_for_meshes( + self.v.astype(np.float32), self.f, query_np + ) + return torch.from_numpy(target_alphas).reshape(shp[:-1]).to(device) + + def gaussian_weighted_distance(self, query: torch.Tensor, sigma): + device = query.device + shp = query.shape + query_np = query.detach().cpu().reshape(-1, 3).numpy() + distances, _, _ = point_mesh_squared_distance( + query_np, self.v.astype(np.float32), self.f + ) + distances = torch.from_numpy(distances).reshape(shp[:-1]).to(device) + weight = torch.exp(-(distances / (2 * 
sigma**2))) + return weight + + +def ce_pq_loss(p, q, weight=None): + def clamp(v, T=0.0001): + return v.clamp(T, 1 - T) + + p = p.view(q.shape) + ce = -1 * (p * torch.log(clamp(q)) + (1 - p) * torch.log(clamp(1 - q))) + if weight is not None: + ce *= weight + return ce.sum() + + +class ShapeLoss(nn.Module): + def __init__(self, guide_shape): + super().__init__() + self.mesh_scale = 0.7 + self.proximal_surface = 0.3 + self.delta = 0.2 + self.shape_path = guide_shape + v, _, _, f, _, _ = read_obj(self.shape_path, float) + mesh = MeshOBJ(v, f) + matrix_rot = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ np.array( + [[0, 0, 1], [0, 1, 0], [-1, 0, 0]] + ) + self.sketchshape = mesh.normalize_mesh(self.mesh_scale) + self.sketchshape = MeshOBJ( + np.ascontiguousarray( + (matrix_rot @ self.sketchshape.v.transpose(1, 0)).transpose(1, 0) + ), + f, + ) + + def forward(self, xyzs, sigmas): + mesh_occ = self.sketchshape.winding_number(xyzs) + if self.proximal_surface > 0: + weight = 1 - self.sketchshape.gaussian_weighted_distance( + xyzs, self.proximal_surface + ) + else: + weight = None + indicator = (mesh_occ > 0.5).float() + nerf_occ = 1 - torch.exp(-self.delta * sigmas) + nerf_occ = nerf_occ.clamp(min=0, max=1.1) + loss = ce_pq_loss( + nerf_occ, indicator, weight=weight + ) # order is important for CE loss + second argument may not be optimized + return loss + + +def shifted_expotional_decay(a, b, c, r): + return a * torch.exp(-b * r) + c + + +def shifted_cosine_decay(a, b, c, r): + return a * torch.cos(b * r + c) + a + + +def perpendicular_component(x: Float[Tensor, "B C H W"], y: Float[Tensor, "B C H W"]): + # get the component of x that is perpendicular to y + eps = torch.ones_like(x[:, 0, 0, 0]) * 1e-6 + return ( + x + - ( + torch.mul(x, y).sum(dim=[1, 2, 3]) + / torch.maximum(torch.mul(y, y).sum(dim=[1, 2, 3]), eps) + ).view(-1, 1, 1, 1) + * y + ) + + +def validate_empty_rays(ray_indices, t_start, t_end): + if ray_indices.nelement() == 0: + print("Warn Empty rays_indices!") + ray_indices = torch.LongTensor([0]).to(ray_indices) + t_start = torch.Tensor([0]).to(ray_indices) + t_end = torch.Tensor([0]).to(ray_indices) + return ray_indices, t_start, t_end + diff --git a/svrm/ldm/utils/typing.py b/svrm/ldm/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..35320a646e5bbd16ee7bf91b1ad804a7a1e2c6ac --- /dev/null +++ b/svrm/ldm/utils/typing.py @@ -0,0 +1,38 @@ +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. 
Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker + diff --git a/svrm/ldm/vis_util.py b/svrm/ldm/vis_util.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddddfc74266ebdf8d19bd5106ffb17580d057b4 --- /dev/null +++ b/svrm/ldm/vis_util.py @@ -0,0 +1,91 @@ +import os +from typing import List, Optional +from PIL import Image +import imageio +import time +import torch +from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj +from pytorch3d.ops import interpolate_face_attributes +from pytorch3d.common.datatypes import Device +from pytorch3d.structures import Meshes +from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + PointLights, + DirectionalLights, + AmbientLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, + TexturesVertex, + camera_position_from_spherical_angles, + BlendParams, +) + + +def render( + obj_filename, + elev=0, + azim=0, + resolution=512, + gif_dst_path='', + n_views=120, + fps=30, + device="cuda:0", + rgb=False +): + ''' + obj_filename: path to obj file + gif_dst_path: + if set a path, will render n_views frames, then save it to a gif file + if not set, will render single frame, then return PIL.Image instance + rgb: if set true, will convert result to rgb image/frame + ''' + # load mesh + mesh = load_objs_as_meshes([obj_filename], device=device) + meshes = mesh.extend(n_views) + + if gif_dst_path != '': + elev = torch.linspace(elev, elev, n_views+1)[:-1] + azim = torch.linspace(0, 360, n_views+1)[:-1] + + # prepare R,T then compute cameras + R, T = look_at_view_transform(dist=1.5, elev=elev, azim=azim) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=49.1) + + # init pytorch3d renderer instance + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=RasterizationSettings( + image_size=resolution, + blur_radius=0.0, + faces_per_pixel=1, + ), + ), + shader=SoftPhongShader( + device=device, + cameras=cameras, + lights=AmbientLights(device=device), + blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)), + ) + ) + images = renderer(meshes) + + # single frame rendering + if gif_dst_path == '': + frame = images[0, ..., :3] if rgb else images[0, ...] + frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8")) + return frame + + # orbit frames rendering + with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer: + for i in range(n_views): + frame = images[i, ..., :3] if rgb else images[i, ...] 
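+            # scale to [0, 255], cast to uint8, and append the frame to the gif writer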
+ frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8")) + writer.append_data(frame) + return gif_dst_path diff --git a/svrm/predictor.py b/svrm/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..19fc4ddd12cbc83caa639c41feba4a2d0e64005a --- /dev/null +++ b/svrm/predictor.py @@ -0,0 +1,150 @@ +# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
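+
+# MV23DPredictor (below) wraps the SVRM reconstruction model: it builds camera-to-world
+# matrices for six fixed azimuths (0-300 deg) plus the user's input view (whose camera is
+# zeroed out), resizes and normalizes the seven input images, and exports a textured mesh
+# via the model's export_mesh_with_uv.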
+ +import os +import math +import time +import torch +import numpy as np +from tqdm import tqdm +from PIL import Image, ImageSequence +from omegaconf import OmegaConf +from torchvision import transforms +from safetensors.torch import save_file, load_file +from .ldm.util import instantiate_from_config +from .ldm.vis_util import render + +class MV23DPredictor(object): + def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60, + render_size=256, device="cuda:0") -> None: + self.device = device + self.elevation = elevation + self.number_view = number_view + self.render_size = render_size + + self.elevation_list = [0, 0, 0, 0, 0, 0, 0] + self.azimuth_list = [0, 60, 120, 180, 240, 300, 0] + + st = time.time() + self.model = self.init_model(ckpt_path, cfg_path) + print(f"=====> mv23d model init time: {time.time() - st}") + + self.input_view_transform = transforms.Compose([ + transforms.Resize(504, interpolation=Image.BICUBIC), + transforms.ToTensor(), + ]) + self.final_input_view_transform = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + + def init_model(self, ckpt_path, cfg_path): + config = OmegaConf.load(cfg_path) + model = instantiate_from_config(config.model) + + weights = load_file("./weights/svrm/svrm.safetensors") + model.load_state_dict(weights) + + model.to(self.device) + model = model.eval() + model.render.half() + print(f'Load model successfully') + return model + + def create_camera_to_world_matrix(self, elevation, azimuth, cam_dis=1.5): + # elevation azimuth are radians + # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere + x = np.cos(elevation) * np.cos(azimuth) + y = np.cos(elevation) * np.sin(azimuth) + z = np.sin(elevation) + + # Calculate camera position, target, and up vectors + camera_pos = np.array([x, y, z]) * cam_dis + target = np.array([0, 0, 0]) + up = np.array([0, 0, 1]) + + # Construct view matrix + forward = target - camera_pos + forward /= np.linalg.norm(forward) + right = np.cross(forward, up) + right /= np.linalg.norm(right) + new_up = np.cross(right, forward) + new_up /= np.linalg.norm(new_up) + cam2world = np.eye(4) + cam2world[:3, :3] = np.array([right, new_up, -forward]).T + cam2world[:3, 3] = camera_pos + return cam2world + + def refine_mask(self, mask, k=16): + mask /= 255.0 + boder_mask = (mask >= -math.pi / 2.0 / k + 0.5) & (mask <= math.pi / 2.0 / k + 0.5) + mask[boder_mask] = 0.5 * np.sin(k * (mask[boder_mask] - 0.5)) + 0.5 + mask[mask < -math.pi / 2.0 / k + 0.5] = 0.0 + mask[mask > math.pi / 2.0 / k + 0.5] = 1.0 + return (mask * 255.0).astype(np.uint8) + + def load_images_and_cameras(self, input_imgs, elevation_list, azimuth_list): + input_image_list = [] + input_cam_list = [] + for input_view_image, elevation, azimuth in zip(input_imgs, elevation_list, azimuth_list): + input_view_image = self.input_view_transform(input_view_image) + input_image_list.append(input_view_image) + + input_view_cam_pos = self.create_camera_to_world_matrix(np.radians(elevation), np.radians(azimuth)) + input_view_cam_intrinsic = np.array([35. / 32, 35. 
/32, 0.5, 0.5]) + input_view_cam = torch.from_numpy( + np.concatenate([input_view_cam_pos.reshape(-1), input_view_cam_intrinsic], 0) + ).float() + input_cam_list.append(input_view_cam) + + pixels_input = torch.stack(input_image_list, dim=0) + input_images = self.final_input_view_transform(pixels_input) + input_cams = torch.stack(input_cam_list, dim=0) + return input_images, input_cams + + def load_data(self, intput_imgs): + assert (6+1) == len(intput_imgs) + + input_images, input_cams = self.load_images_and_cameras(intput_imgs, self.elevation_list, self.azimuth_list) + input_cams[-1, :] = 0 # for user input view + + data = {} + data["input_view"] = input_images.unsqueeze(0).to(self.device) # 1 4 3 512 512 + data["input_view_cam"] = input_cams.unsqueeze(0).to(self.device) # 1 4 20 + return data + + @torch.no_grad() + def predict( + self, + intput_imgs, + save_dir = "outputs/", + image_input = None, + target_face_count = 10000, + do_texture_mapping = True, + ): + os.makedirs(save_dir, exist_ok=True) + print(save_dir) + + with torch.cuda.amp.autocast(): + self.model.export_mesh_with_uv( + data = self.load_data(intput_imgs), + out_dir = save_dir, + target_face_count = target_face_count, + do_texture_mapping = do_texture_mapping + ) diff --git a/svrm/utils/camera_utils.py b/svrm/utils/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb15e74ab632a30b8c78b21511b3fbfc6ca46d4 --- /dev/null +++ b/svrm/utils/camera_utils.py @@ -0,0 +1,90 @@ +import math +import numpy as np + +def compute_extrinsic_matrix(elevation, azimuth, camera_distance): + # 将角度转换为弧度 + elevation_rad = np.radians(elevation) + azimuth_rad = np.radians(azimuth) + + R = np.array([ + [np.cos(azimuth_rad), 0, -np.sin(azimuth_rad)], + [0, 1, 0], + [np.sin(azimuth_rad), 0, np.cos(azimuth_rad)], + ], dtype=np.float32) + + R = R @ np.array([ + [1, 0, 0], + [0, np.cos(elevation_rad), -np.sin(elevation_rad)], + [0, np.sin(elevation_rad), np.cos(elevation_rad)] + ], dtype=np.float32) + + # 构建平移矩阵 T (3x1) + T = np.array([[camera_distance], [0], [0]], dtype=np.float32) + T = R @ T + + # 组合成 4x4 的变换矩阵 + extrinsic_matrix = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]], dtype=np.float32))) + + return extrinsic_matrix + + +def transform_camera_pose(im_pose, ori_pose, new_pose): + T = new_pose @ ori_pose.T + transformed_poses = [] + + for pose in im_pose: + transformed_pose = T @ pose + transformed_poses.append(transformed_pose) + + return transformed_poses + +def compute_fov(intrinsic_matrix): + # 获取内参矩阵中的焦距值 + fx = intrinsic_matrix[0, 0] + fy = intrinsic_matrix[1, 1] + + h, w = intrinsic_matrix[0,2]*2, intrinsic_matrix[1,2]*2 + + # 计算水平和垂直方向的FOV值 + fov_x = 2 * math.atan(w / (2 * fx)) * 180 / math.pi + fov_y = 2 * math.atan(h / (2 * fy)) * 180 / math.pi + + return fov_x, fov_y + + + +def rotation_matrix_to_quaternion(rotation_matrix): + rot = Rotation.from_matrix(rotation_matrix) + quaternion = rot.as_quat() + return quaternion + +def quaternion_to_rotation_matrix(quaternion): + rot = Rotation.from_quat(quaternion) + rotation_matrix = rot.as_matrix() + return rotation_matrix + +def remap_points(img_size, match, size=512): + H, W, _ = img_size + + S = max(W, H) + new_W = int(round(W * size / S)) + new_H = int(round(H * size / S)) + cx, cy = new_W // 2, new_H // 2 + + # 计算变换后的图像中心点坐标 + halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8 + + dw, dh = cx - halfw, cy - halfh + + # 初始化一个新的数组来存储映射回原始图像的点坐标 + new_match = np.zeros_like(match) + + # 将变换后的点坐标映射回原始图像 + new_match[:, 0] = (match[:, 0] + dw) / 
new_W * W + new_match[:, 1] = (match[:, 1] + dh) / new_H * H + + #print(dw,new_W,W,dh,new_H,H) + + return new_match + + \ No newline at end of file diff --git a/svrm/utils/img_utils.py b/svrm/utils/img_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6db1c08ed09d5fe1d774f6057e2a7d2162b08d --- /dev/null +++ b/svrm/utils/img_utils.py @@ -0,0 +1,217 @@ +import os +import cv2 +import numpy as np +from skimage.metrics import hausdorff_distance +from matplotlib import pyplot as plt + + +def get_input_imgs_path(input_data_dir): + path = {} + names = ['000', 'ori_000'] + for name in names: + jpg_path = os.path.join(input_data_dir, f"{name}.jpg") + png_path = os.path.join(input_data_dir, f"{name}.png") + if os.path.exists(jpg_path): + path[name] = jpg_path + elif os.path.exists(png_path): + path[name] = png_path + return path + + +def rgba_to_rgb(image, bg_color=[255, 255, 255]): + if image.shape[-1] == 3: return image + + rgba = image.astype(float) + rgb = rgba[:, :, :3].copy() + alpha = rgba[:, :, 3] / 255.0 + + bg = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32) + bg = bg * np.array(bg_color, dtype=np.float32) + + rgb = rgb * alpha[:, :, np.newaxis] + bg * (1 - alpha[:, :, np.newaxis]) + rgb = rgb.astype(np.uint8) + + return rgb + + +def resize_with_aspect_ratio(image1, image2, pad_value=[255, 255, 255]): + aspect_ratio1 = float(image1.shape[1]) / float(image1.shape[0]) + aspect_ratio2 = float(image2.shape[1]) / float(image2.shape[0]) + + top_pad, bottom_pad, left_pad, right_pad = 0, 0, 0, 0 + + if aspect_ratio1 < aspect_ratio2: + new_width = (aspect_ratio2 * image1.shape[0]) + right_pad = left_pad = int((new_width - image1.shape[1]) / 2) + else: + new_height = (image1.shape[1] / aspect_ratio2) + bottom_pad = top_pad = int((new_height - image1.shape[0]) / 2) + + image1_padded = cv2.copyMakeBorder( + image1, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT, value=pad_value + ) + return image1_padded + + +def estimate_img_mask(image): + # 转换为灰度图像 + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # 使用大津法进行阈值分割 + # _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + # mask_otsu = thresh.astype(bool) + # thresh_gray = 240 + + # 使用 Canny 边缘检测算法找到边缘 + edges = cv2.Canny(gray, 20, 50) + + # 使用形态学操作扩展边缘 + kernel = np.ones((3, 3), np.uint8) + edges_dilated = cv2.dilate(edges, kernel, iterations=1) + + contours, _ = cv2.findContours(edges_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # 创建一个空的 mask + mask = np.zeros_like(gray, dtype=np.uint8) + + # 根据轮廓信息填充 mask(使用 thickness=cv2.FILLED 参数) + cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED) + mask = mask.astype(bool) + + return mask + + +def compute_img_diff(img1, img2, matches1, matches1_from_2, vis=False): + scale = 0.125 + gray_trunc_thres = 25 / 255.0 + + # Match + if matches1.shape[0] > 0: + match_scale = np.max(np.ptp(matches1, axis=-1)) + match_dists = np.sqrt(np.sum((matches1 - matches1_from_2) ** 2, axis=-1)) + dist_threshold = match_scale * 0.01 + match_num = np.sum(match_dists <= dist_threshold) + match_rate = np.mean(match_dists <= dist_threshold) + else: + match_num = 0 + match_rate = 0 + + # IOU + img1_mask = estimate_img_mask(img1) + img2_mask = estimate_img_mask(img2) + img_intersection = (img1_mask == 1) & (img2_mask == 1) + img_union = (img1_mask == 1) | (img2_mask == 1) + intersection = np.sum(img_intersection == 1) + union = np.sum(img_union == 1) + mask_iou = intersection / union if union != 0 else 0 + + # Gray + height, 
width = img1.shape[:2] + img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) + img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY) + img1_gray = cv2.GaussianBlur(img1_gray, (7, 7), 0) + img2_gray = cv2.GaussianBlur(img2_gray, (7, 7), 0) + + # Gray Diff + img1_gray_small = cv2.resize(img1_gray, (int(width * scale), int(height * scale)), + interpolation=cv2.INTER_LINEAR) / 255.0 + img2_gray_small = cv2.resize(img2_gray, (int(width * scale), int(height * scale)), + interpolation=cv2.INTER_LINEAR) / 255.0 + img_gray_small_diff = np.abs(img1_gray_small - img2_gray_small) + gray_diff = img_gray_small_diff.sum() / (union * scale) if union != 0 else 1 + + img_gray_small_diff_trunc = img_gray_small_diff.copy() + img_gray_small_diff_trunc[img_gray_small_diff < gray_trunc_thres] = 0 + gray_diff_trunc = img_gray_small_diff_trunc.sum() / (union * scale) if union != 0 else 1 + + # Edge + img1_edge = cv2.Canny(img1_gray, 100, 200) + img2_edge = cv2.Canny(img2_gray, 100, 200) + bw_edges1 = (img1_edge > 0).astype(bool) + bw_edges2 = (img2_edge > 0).astype(bool) + hausdorff_dist = hausdorff_distance(bw_edges1, bw_edges2) + if vis == True: + fig, axs = plt.subplots(1, 4, figsize=(15, 5)) + axs[0].imshow(img1_gray, cmap='gray') + axs[0].set_title('Img1') + axs[1].imshow(img2_gray, cmap='gray') + axs[1].set_title('Img2') + axs[2].imshow(img1_mask) + axs[2].set_title('Mask1') + axs[3].imshow(img2_mask) + axs[3].set_title('Mask2') + plt.show() + plt.figure() + mask_cmp = np.zeros((height, width, 3)) + mask_cmp[img_intersection, 1] = 1 + mask_cmp[img_union, 0] = 1 + plt.imshow(mask_cmp) + plt.show() + fig, axs = plt.subplots(1, 4, figsize=(15, 5)) + axs[0].imshow(img1_gray_small, cmap='gray') + axs[0].set_title('Img1 Gray') + axs[1].imshow(img2_gray_small, cmap='gray') + axs[1].set_title('Img2 Gary') + axs[2].imshow(img_gray_small_diff, cmap='gray') + axs[2].set_title('diff') + axs[3].imshow(img_gray_small_diff_trunc, cmap='gray') + axs[3].set_title('diff_trunct') + plt.show() + fig, axs = plt.subplots(1, 2, figsize=(15, 5)) + axs[0].imshow(img1_edge, cmap='gray') + axs[0].set_title('img1_edge') + axs[1].imshow(img2_edge, cmap='gray') + axs[1].set_title('img2_edge') + plt.show() + + info = {} + info['match_num'] = match_num + info['match_rate'] = match_rate + info['mask_iou'] = mask_iou + info['gray_diff'] = gray_diff + info['gray_diff_trunc'] = gray_diff_trunc + info['hausdorff_dist'] = hausdorff_dist + return info + + +def predict_match_success_human(info): + match_num = info['match_num'] + match_rate = info['match_rate'] + mask_iou = info['mask_iou'] + gray_diff = info['gray_diff'] + gray_diff_trunc = info['gray_diff_trunc'] + hausdorff_dist = info['hausdorff_dist'] + + if mask_iou > 0.95: + return True + + if match_num < 20 or match_rate < 0.7: + return False + + if mask_iou > 0.80 and gray_diff < 0.040 and gray_diff_trunc < 0.010: + return True + + if mask_iou > 0.70 and gray_diff < 0.050 and gray_diff_trunc < 0.008: + return True + + ''' + if match_rate<0.70 or match_num<3000: + return False + + if (mask_iou>0.85 and hausdorff_dist<20)or (gray_diff<0.015 and gray_diff_trunc<0.01) or match_rate>=0.90: + return True + ''' + + return False + + +def predict_match_success(info, model=None): + if model == None: + return predict_match_success_human(info) + else: + feat_name = ['match_num', 'match_rate', 'mask_iou', 'gray_diff', 'gray_diff_trunc', 'hausdorff_dist'] + # 提取特征 + features = [info[f] for f in feat_name] + # 预测 + pred = model.predict([features])[0] + return pred >= 0.5 \ No newline at end of file diff 
--git a/svrm/utils/log_utils.py b/svrm/utils/log_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d0a5abcff36904ac7d748f10f8c94e37007868
--- /dev/null
+++ b/svrm/utils/log_utils.py
@@ -0,0 +1,21 @@
+import cv2
+import numpy as np
+
+def txt_to_img(text, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, font_thickness=2, img_width=1000, img_height=100, text_color=(0, 0, 0), bg_color=(255, 255, 255)):
+    lines = text.split('\n')
+    img_lines = []
+    for line in lines:
+        # measure the rendered size of this text line
+        line_size, _ = cv2.getTextSize(line, font, font_scale, font_thickness)
+        line_width, line_height = line_size
+        # create an image canvas holding the current line
+        img_line = np.full((int(line_height * 1.5), img_width, 3), bg_color, dtype=np.uint8)
+        text_x, text_y = 0, line_height
+        cv2.putText(img_line, line, (text_x, text_y), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
+
+        img_lines.append(img_line)
+
+
+    # stack the per-line images vertically
+    img = np.vstack(img_lines)
+    return img
\ No newline at end of file
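
For orientation, below is a minimal, non-authoritative usage sketch of the predictor and renderer introduced in this patch. The view filenames, config path, and exported mesh name are illustrative assumptions rather than files created by this diff; note that `MV23DPredictor.init_model` itself hard-codes loading `./weights/svrm/svrm.safetensors`.

```python
# Minimal sketch (assumed paths/filenames): reconstruct a mesh from seven views,
# then render an orbit gif of the result with svrm.ldm.vis_util.render.
from PIL import Image

from svrm.predictor import MV23DPredictor
from svrm.ldm.vis_util import render

# six fixed-azimuth views (0..300 deg) followed by the original user view
views = [Image.open(f"outputs/views/{i:03d}.png").convert("RGB") for i in range(7)]

predictor = MV23DPredictor(
    ckpt_path="weights/svrm/svrm.safetensors",  # assumed location; init_model loads ./weights/svrm/svrm.safetensors
    cfg_path="svrm/configs/svrm.yaml",          # assumed config path
    device="cuda:0",
)
predictor.predict(
    views,
    save_dir="outputs/mesh/",
    target_face_count=10000,
    do_texture_mapping=True,
)

# the exported .obj name depends on export_mesh_with_uv; "mesh.obj" is a placeholder
render("outputs/mesh/mesh.obj", gif_dst_path="outputs/mesh/orbit.gif", n_views=120, fps=30)
```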