diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..4720b1a53b512a49e81661f68abecdc267a7bae7 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,35 +1,37 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+demos/example_003.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 4bb7a445b43a31999cb8d125c044a86cfdf08fd4..251cf678ebe65c33fbb110fb78495e15c99fb1bd 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,11 @@
---
-title: InstantIR
-emoji: 🖼
+title: Hunyuan3D-1.0
+emoji: 😻
colorFrom: purple
colorTo: red
sdk: gradio
sdk_version: 4.42.0
app_file: app.py
pinned: false
-license: apache-2.0
-short_description: diffusion-based Image Restoration model
----
-
-Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
\ No newline at end of file
+short_description: Text-to-3D and Image-to-3D Generation
+---
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bc5dde13399a86d0ad1438afb50515f76eb550c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,284 @@
+import os
+import warnings
+from huggingface_hub import snapshot_download
+import gradio as gr
+from glob import glob
+import shutil
+import torch
+import numpy as np
+from PIL import Image
+from einops import rearrange
+import argparse
+
+# Suppress warnings
+warnings.simplefilter('ignore', category=UserWarning)
+warnings.simplefilter('ignore', category=FutureWarning)
+warnings.simplefilter('ignore', category=DeprecationWarning)
+
+def download_models():
+ # Create weights directory if it doesn't exist
+ os.makedirs("weights", exist_ok=True)
+ os.makedirs("weights/hunyuanDiT", exist_ok=True)
+
+ # Download Hunyuan3D-1 model
+ try:
+        snapshot_download(
+            repo_id="tencent/Hunyuan3D-1",
+            local_dir="./weights",
+            resume_download=True
+        )
+ print("Successfully downloaded Hunyuan3D-1 model")
+ except Exception as e:
+ print(f"Error downloading Hunyuan3D-1: {e}")
+
+ # Download HunyuanDiT model
+ try:
+        snapshot_download(
+            repo_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
+            local_dir="./weights/hunyuanDiT",
+            resume_download=True
+        )
+ print("Successfully downloaded HunyuanDiT model")
+ except Exception as e:
+ print(f"Error downloading HunyuanDiT: {e}")
+
+# Download models before starting the app
+download_models()
+
+# Parse arguments
+parser = argparse.ArgumentParser()
+parser.add_argument("--use_lite", default=False, action="store_true")
+parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str)
+parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str)
+parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str)
+parser.add_argument("--save_memory", default=False, action="store_true")
+parser.add_argument("--device", default="cuda:0", type=str)
+args = parser.parse_args()
+
+# Constants
+CONST_PORT = 8080
+CONST_MAX_QUEUE = 1
+CONST_SERVER = '0.0.0.0'
+
+CONST_HEADER = '''
+
+Official 🤗 Gradio Demo
+
+'''
+
+# Helper functions
+def get_example_img_list():
+ print('Loading example img list ...')
+ return sorted(glob('./demos/example_*.png'))
+
+def get_example_txt_list():
+    print('Loading example txt list ...')
+    txt_list = []
+    with open('./demos/example_list.txt') as f:
+        for line in f:
+            txt_list.append(line.strip())
+    return txt_list
+
+example_is = get_example_img_list()
+example_ts = get_example_txt_list()
+
+# Import required workers
+from infer import seed_everything, save_gif
+from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
+
+# Initialize workers
+worker_xbg = Removebg()
+print(f"loading {args.text2image_path}")
+worker_t2i = Text2Image(
+ pretrain=args.text2image_path,
+ device=args.device,
+ save_memory=args.save_memory
+)
+worker_i2v = Image2Views(
+ use_lite=args.use_lite,
+ device=args.device
+)
+worker_v23 = Views2Mesh(
+ args.mv23d_cfg_path,
+ args.mv23d_ckt_path,
+ use_lite=args.use_lite,
+ device=args.device
+)
+worker_gif = GifRenderer(args.device)
+
+# Pipeline stages
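+# The stages below are chained by the Gradio event handlers at the bottom of this file:
+#   stage_0_t2i : text prompt or uploaded image -> conditioning image in ./outputs/app_output/<id>/
+#   stage_1_xbg : conditioning image -> RGBA image with the background removed
+#   stage_2_i2v : RGBA image -> 3x2 multi-view grid (+ views.gif) and the processed condition image
+#   stage_3_v23 : multi-view grid + condition image -> mesh (.obj / .glb)
+#   stage_4_gif : optional turntable GIF rendered from the reconstructed mesh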
+def stage_0_t2i(text, image, seed, step):
+ os.makedirs('./outputs/app_output', exist_ok=True)
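+    # reuse up to 30 numbered output folders: pick the lowest free id and pre-clear the next slot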
+ exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
+ cur_id = min(set(range(30)) - exists) if len(exists) < 30 else 0
+
+ if os.path.exists(f"./outputs/app_output/{(cur_id + 1) % 30}"):
+ shutil.rmtree(f"./outputs/app_output/{(cur_id + 1) % 30}")
+ save_folder = f'./outputs/app_output/{cur_id}'
+ os.makedirs(save_folder, exist_ok=True)
+
+ dst = save_folder + '/img.png'
+
+ if not text:
+ if image is None:
+ return dst, save_folder
+ image.save(dst)
+ return dst, save_folder
+
+ image = worker_t2i(text, seed, step)
+ image.save(dst)
+ dst = worker_xbg(image, save_folder)
+ return dst, save_folder
+
+def stage_1_xbg(image, save_folder):
+ if isinstance(image, str):
+ image = Image.open(image)
+ dst = save_folder + '/img_nobg.png'
+ rgba = worker_xbg(image)
+ rgba.save(dst)
+ return dst
+
+def stage_2_i2v(image, seed, step, save_folder):
+ if isinstance(image, str):
+ image = Image.open(image)
+ gif_dst = save_folder + '/views.gif'
+ res_img, pils = worker_i2v(image, seed, step)
+ save_gif(pils, gif_dst)
+ views_img, cond_img = res_img[0], res_img[1]
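+    # split the 3x2 grid of generated views into tiles, reorder them into display order,
+    # then tile them back into a 2x3 grid for the preview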
+ img_array = np.asarray(views_img, dtype=np.uint8)
+ show_img = rearrange(img_array, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+ show_img = show_img[worker_i2v.order, ...]
+ show_img = rearrange(show_img, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
+ show_img = Image.fromarray(show_img)
+ return views_img, cond_img, show_img
+
+def stage_3_v23(views_pil, cond_pil, seed, save_folder, target_face_count=30000,
+ do_texture_mapping=True, do_render=True):
+ do_texture_mapping = do_texture_mapping or do_render
+ obj_dst = save_folder + '/mesh_with_colors.obj'
+ glb_dst = save_folder + '/mesh.glb'
+ worker_v23(
+ views_pil,
+ cond_pil,
+ seed=seed,
+ save_folder=save_folder,
+ target_face_count=target_face_count,
+ do_texture_mapping=do_texture_mapping
+ )
+ return obj_dst, glb_dst
+
+def stage_4_gif(obj_dst, save_folder, do_render_gif=True):
+ if not do_render_gif:
+ return None
+ gif_dst = save_folder + '/output.gif'
+ worker_gif(
+ save_folder + '/mesh.obj',
+ gif_dst_path=gif_dst
+ )
+ return gif_dst
+
+# Gradio Interface
+with gr.Blocks() as demo:
+ gr.Markdown(CONST_HEADER)
+
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=2):
+ with gr.Tab("Text to 3D"):
+ with gr.Column():
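+                    # default prompt (Chinese): "A black-and-white panda sits centered on a white
+                    # background, in a cartoon style with a cute atmosphere."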
+ text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
+ lines=1, max_lines=10, label='Input text')
+ with gr.Row():
+ textgen_seed = gr.Number(value=0, label="T2I seed", precision=0)
+ textgen_step = gr.Number(value=25, label="T2I step", precision=0)
+ textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
+ textgen_STEP = gr.Number(value=50, label="Gen step", precision=0)
+ textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
+
+ with gr.Row():
+ textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
+ textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
+ textgen_submit = gr.Button("Generate", variant="primary")
+
+ gr.Examples(examples=example_ts, inputs=[text], label="Txt examples")
+
+ with gr.Tab("Image to 3D"):
+ with gr.Column():
+ input_image = gr.Image(label="Input image", width=256, height=256,
+ type="pil", image_mode="RGBA", sources="upload")
+ with gr.Row():
+ imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
+ imggen_STEP = gr.Number(value=50, label="Gen step", precision=0)
+ imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
+
+ with gr.Row():
+ imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
+ imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
+ imggen_submit = gr.Button("Generate", variant="primary")
+
+ gr.Examples(examples=example_is, inputs=[input_image], label="Img examples")
+
+ with gr.Column(scale=3):
+ with gr.Tab("rembg image"):
+ rem_bg_image = gr.Image(label="No background image", width=256, height=256,
+ type="pil", image_mode="RGBA")
+
+ with gr.Tab("Multi views"):
+ result_image = gr.Image(label="Multi views", type="pil")
+ with gr.Tab("Obj"):
+ result_3dobj = gr.Model3D(label="Output obj")
+ with gr.Tab("Glb"):
+ result_3dglb = gr.Model3D(label="Output glb")
+ with gr.Tab("GIF"):
+ result_gif = gr.Image(label="Rendered GIF")
+
+ # States
+ none = gr.State(None)
+ save_folder = gr.State()
+ cond_image = gr.State()
+ views_image = gr.State()
+ text_image = gr.State()
+
+ # Event handlers
+ textgen_submit.click(
+ fn=stage_0_t2i,
+ inputs=[text, none, textgen_seed, textgen_step],
+ outputs=[rem_bg_image, save_folder],
+ ).success(
+ fn=stage_2_i2v,
+ inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
+ outputs=[views_image, cond_image, result_image],
+ ).success(
+ fn=stage_3_v23,
+ inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces,
+ textgen_do_texture_mapping, textgen_do_render_gif],
+ outputs=[result_3dobj, result_3dglb],
+ ).success(
+ fn=stage_4_gif,
+ inputs=[result_3dglb, save_folder, textgen_do_render_gif],
+ outputs=[result_gif],
+ )
+
+ imggen_submit.click(
+ fn=stage_0_t2i,
+ inputs=[none, input_image, textgen_seed, textgen_step],
+ outputs=[text_image, save_folder],
+ ).success(
+ fn=stage_1_xbg,
+ inputs=[text_image, save_folder],
+ outputs=[rem_bg_image],
+ ).success(
+ fn=stage_2_i2v,
+ inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
+ outputs=[views_image, cond_image, result_image],
+ ).success(
+ fn=stage_3_v23,
+ inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces,
+ imggen_do_texture_mapping, imggen_do_render_gif],
+ outputs=[result_3dobj, result_3dglb],
+ ).success(
+ fn=stage_4_gif,
+ inputs=[result_3dglb, save_folder, imggen_do_render_gif],
+ outputs=[result_gif],
+ )
+
+ demo.queue(max_size=CONST_MAX_QUEUE)
+ demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
\ No newline at end of file
diff --git a/assets/logo.png b/assets/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..06dab332d2138fe5ddbd0555780817e9602e779d
Binary files /dev/null and b/assets/logo.png differ
diff --git a/assets/overview_3.png b/assets/overview_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..edbec8202aeed23a277c94dc92d15515db2b18dc
Binary files /dev/null and b/assets/overview_3.png differ
diff --git a/assets/radar.png b/assets/radar.png
new file mode 100644
index 0000000000000000000000000000000000000000..fd57b17e36f15e9e78d4e8177fd62cef00f0e94b
Binary files /dev/null and b/assets/radar.png differ
diff --git a/assets/runtime.png b/assets/runtime.png
new file mode 100644
index 0000000000000000000000000000000000000000..220d288e9f78cfb4e6bf4ed811cb2890e52925ae
Binary files /dev/null and b/assets/runtime.png differ
diff --git a/assets/teaser.png b/assets/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..f459167e3a32df9629dce6df1ffc41a6994825af
--- /dev/null
+++ b/assets/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af24eeebe39864d377b7ef8e11521a8b7cba964c14032cc28bd0d95bd5219c00
+size 3097514
diff --git a/demos/example_000.png b/demos/example_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..6891f23a3a621255ee97a100bb40b5a27a9cfa20
Binary files /dev/null and b/demos/example_000.png differ
diff --git a/demos/example_001.png b/demos/example_001.png
new file mode 100644
index 0000000000000000000000000000000000000000..667dc22ce710ab308de3589b2f42cf7d10131ec8
Binary files /dev/null and b/demos/example_001.png differ
diff --git a/demos/example_002.png b/demos/example_002.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c359f2bcffa3603462249a5046fe4f7a63c0435
Binary files /dev/null and b/demos/example_002.png differ
diff --git a/demos/example_003.png b/demos/example_003.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c1e30acfb67a99da531571d2feaf5d3060ffefc
--- /dev/null
+++ b/demos/example_003.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d947e0ef10baf761abb78d2842519ae7428bc6eadab26a159510ddcaf2a47e67
+size 1066046
diff --git a/demos/example_list.txt b/demos/example_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..09d331354dc6978d24793bb4557b672485a9f5c9
--- /dev/null
+++ b/demos/example_list.txt
@@ -0,0 +1,2 @@
+a pot of green plants grows in a red flower pot.
+a lovely rabbit eating carrots
diff --git a/infer/__init__.py b/infer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db8c89d17f7bdf774877d9d52f0c42f16d87617
--- /dev/null
+++ b/infer/__init__.py
@@ -0,0 +1,28 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+from .utils import seed_everything, timing_decorator, auto_amp_inference
+from .rembg import Removebg
+from .text_to_image import Text2Image
+from .image_to_views import Image2Views, save_gif
+from .views_to_mesh import Views2Mesh
+from .gif_render import GifRenderer
diff --git a/infer/gif_render.py b/infer/gif_render.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5bceffba12b5a595e7be889cc9c979f3db8c98
--- /dev/null
+++ b/infer/gif_render.py
@@ -0,0 +1,55 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+from svrm.ldm.vis_util import render
+from .utils import seed_everything, timing_decorator
+
+class GifRenderer():
+    '''
+    render a turntable GIF (or single frames) of a mesh using pytorch3d
+    '''
+ def __init__(self, device="cuda:0"):
+ self.device = device
+
+ @timing_decorator("gif render")
+ def __call__(
+ self,
+ obj_filename,
+ elev=0,
+ azim=0,
+ resolution=512,
+ gif_dst_path='',
+ n_views=120,
+ fps=30,
+ rgb=True
+ ):
+ render(
+ obj_filename,
+ elev=elev,
+ azim=azim,
+ resolution=resolution,
+ gif_dst_path=gif_dst_path,
+ n_views=n_views,
+ fps=fps,
+ device=self.device,
+ rgb=rgb
+ )
diff --git a/infer/image_to_views.py b/infer/image_to_views.py
new file mode 100644
index 0000000000000000000000000000000000000000..97a9ce3fe22a9bbd0bc0e05ea22c9007fc3e296d
--- /dev/null
+++ b/infer/image_to_views.py
@@ -0,0 +1,81 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import os
+import time
+import torch
+import random
+import numpy as np
+from PIL import Image
+from einops import rearrange
+
+from .utils import seed_everything, timing_decorator, auto_amp_inference
+from .utils import get_parameter_number, set_parameter_grad_false
+from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
+from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline
+
+
+def save_gif(pils, save_path, df=False):
+    # save a list of PIL.Image frames as a GIF with a total duration of ~4 seconds
+    spf = 4000 / len(pils)  # duration per frame, in milliseconds
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ pils[0].save(save_path, format="GIF", save_all=True, append_images=pils[1:], duration=spf, loop=0)
+ return save_path
+
+
+class Image2Views():
+ def __init__(self, device="cuda:0", use_lite=False):
+ self.device = device
+ if use_lite:
+ self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
+ "./weights/mvd_lite",
+ torch_dtype = torch.float16,
+ use_safetensors = True,
+ )
+ else:
+ self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
+ "./weights/mvd_std",
+ torch_dtype = torch.float16,
+ use_safetensors = True,
+ )
+ self.pipe = self.pipe.to(device)
+ self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
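+        # the lite and std pipelines lay out the 3x2 view grid differently; `order` maps grid tiles
+        # to the view order used downstream (preview and mesh reconstruction)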
+ set_parameter_grad_false(self.pipe.unet)
+ print('image2views unet model', get_parameter_number(self.pipe.unet))
+
+ @torch.no_grad()
+ @timing_decorator("image to views")
+ @auto_amp_inference
+ def __call__(self, pil_img, seed=0, steps=50, guidance_scale=2.0, guidance_curve=lambda t:2.0):
+ seed_everything(seed)
+        generator = torch.Generator(device=self.device).manual_seed(seed)  # seed explicitly for reproducibility
+        res_img = self.pipe(pil_img,
+                            num_inference_steps=steps,
+                            guidance_scale=guidance_scale,
+                            guidance_curve=guidance_curve,
+                            generator=generator).images
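+        # res_img[0] is the generated 3x2 multi-view grid, res_img[1] is the processed condition image;
+        # pils holds the condition image followed by the views in the pipeline's view order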
+ show_image = rearrange(np.asarray(res_img[0], dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+ pils = [res_img[1]]+[Image.fromarray(show_image[idx]) for idx in self.order]
+ torch.cuda.empty_cache()
+ return res_img, pils
+
\ No newline at end of file
diff --git a/infer/rembg.py b/infer/rembg.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad1031e36670a393b88218a594939cfaab1262c
--- /dev/null
+++ b/infer/rembg.py
@@ -0,0 +1,26 @@
+from rembg import remove, new_session
+from .utils import timing_decorator
+
+class Removebg():
+ def __init__(self, name="u2net"):
+        '''
+        name: rembg model/session name (default "u2net")
+        '''
+ self.session = new_session(name)
+
+ @timing_decorator("remove background")
+ def __call__(self, rgb_img, force=False):
+        '''
+        inputs:
+            rgb_img: PIL.Image, RGB mode expected
+            force: bool, if True and the input is already RGBA, convert to RGB and re-run matting
+        return:
+            rgba_img: PIL.Image in RGBA mode
+        '''
+ if rgb_img.mode == "RGBA":
+ if force:
+ rgb_img = rgb_img.convert("RGB")
+ else:
+ return rgb_img
+ rgba_img = remove(rgb_img, session=self.session)
+ return rgba_img
diff --git a/infer/text_to_image.py b/infer/text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce1ed3b767b0a5d7ecf4260146f350fed0a72618
--- /dev/null
+++ b/infer/text_to_image.py
@@ -0,0 +1,80 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import torch
+from .utils import seed_everything, timing_decorator, auto_amp_inference
+from .utils import get_parameter_number, set_parameter_grad_false
+from diffusers import HunyuanDiTPipeline, AutoPipelineForText2Image
+
+class Text2Image():
+ def __init__(self, pretrain="weights/hunyuanDiT", device="cuda:0", save_memory=False):
+        '''
+        pretrain: path to the text-to-image (HunyuanDiT) weights
+        save_memory: if True, keep the pipeline on CPU and move it to the GPU only while generating
+        '''
+ self.save_memory = save_memory
+ self.device = device
+ self.pipe = AutoPipelineForText2Image.from_pretrained(
+ pretrain,
+ torch_dtype = torch.float16,
+ enable_pag = True,
+ pag_applied_layers = ["blocks.(16|17|18|19)"]
+ )
+ set_parameter_grad_false(self.pipe.transformer)
+ print('text2image transformer model', get_parameter_number(self.pipe.transformer))
+ if not save_memory:
+ self.pipe = self.pipe.to(device)
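+        # negative prompt (Chinese), roughly: "text, close-up, cropped, out of frame, worst quality,
+        # low quality, JPEG artifacts, duplicate, morbid, mutilated, extra fingers, mutated hands,
+        # poorly drawn hands/face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions,
+        # extra limbs, cloned face, disfigured, missing arms/legs, fused fingers, too many fingers, long neck"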
+ self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
+ "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
+ "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
+
+ @torch.no_grad()
+ @timing_decorator('text to image')
+ @auto_amp_inference
+ def __call__(self, *args, **kwargs):
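+        # with save_memory, the pipeline normally lives on CPU and is moved to the GPU only for this call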
+ if self.save_memory:
+ self.pipe = self.pipe.to(self.device)
+ torch.cuda.empty_cache()
+ res = self.call(*args, **kwargs)
+ self.pipe = self.pipe.to("cpu")
+ else:
+ res = self.call(*args, **kwargs)
+ torch.cuda.empty_cache()
+ return res
+
+ def call(self, prompt, seed=0, steps=25):
+        '''
+        inputs:
+            prompt: str
+            seed: int
+            steps: int
+        return:
+            rgb: PIL.Image
+        '''
+        # append "white background, 3D style, best quality" (in Chinese) to steer the generation
+        prompt = prompt + ",白色背景,3D风格,最佳质量"
+ seed_everything(seed)
+ generator = torch.Generator(device=self.device)
+ if seed is not None: generator = generator.manual_seed(int(seed))
+ rgb = self.pipe(prompt=prompt, negative_prompt=self.neg_txt, num_inference_steps=steps,
+ pag_scale=1.3, width=1024, height=1024, generator=generator, return_dict=False)[0][0]
+ torch.cuda.empty_cache()
+ return rgb
+
\ No newline at end of file
diff --git a/infer/utils.py b/infer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..659c9ac26e33b352d74083b104487871e6249b6a
--- /dev/null
+++ b/infer/utils.py
@@ -0,0 +1,77 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import os
+import time
+import random
+import numpy as np
+import torch
+from torch.cuda.amp import autocast, GradScaler
+from functools import wraps
+
+def seed_everything(seed):
+    '''
+    seed python, numpy and torch RNGs
+    '''
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+def timing_decorator(category: str):
+    '''
+    decorator that prints the elapsed time of each call, tagged with `category`
+    '''
+ def decorator(func):
+ func.call_count = 0
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ start_time = time.time()
+ result = func(*args, **kwargs)
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ func.call_count += 1
+ print(f"[HunYuan3D]-[{category}], cost time: {elapsed_time:.4f}s") # huiwen
+ return result
+ return wrapper
+ return decorator
+
+def auto_amp_inference(func):
+    '''
+    run the wrapped function inside torch.cuda.amp.autocast()
+    '''
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ with autocast():
+ output = func(*args, **kwargs)
+ return output
+ return wrapper
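+
+# Typical usage in this repo (see infer/image_to_views.py): the decorators are stacked so a call
+# runs without gradients, is timed, and executes under CUDA autocast, e.g.
+#
+#   @torch.no_grad()
+#   @timing_decorator("image to views")
+#   @auto_amp_inference
+#   def __call__(self, pil_img, seed=0, steps=50):
+#       ...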
+
+def get_parameter_number(model):
+ total_num = sum(p.numel() for p in model.parameters())
+ trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return {'Total': total_num, 'Trainable': trainable_num}
+
+def set_parameter_grad_false(model):
+ for p in model.parameters():
+ p.requires_grad = False
diff --git a/infer/views_to_mesh.py b/infer/views_to_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0fe680f0fe235501e503bfd410e56931ee1eeff
--- /dev/null
+++ b/infer/views_to_mesh.py
@@ -0,0 +1,94 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import os
+import time
+import torch
+import random
+import numpy as np
+from PIL import Image, ImageSequence
+from einops import rearrange
+
+from .utils import seed_everything, timing_decorator, auto_amp_inference
+from .utils import get_parameter_number, set_parameter_grad_false
+from svrm.predictor import MV23DPredictor
+
+
+class Views2Mesh():
+ def __init__(self, mv23d_cfg_path, mv23d_ckt_path, device="cuda:0", use_lite=False):
+        '''
+        mv23d_cfg_path: path to the SVRM config yaml
+        mv23d_ckt_path: path to the SVRM checkpoint
+        use_lite: whether the multi-view grid comes from the lite pipeline (affects view ordering)
+        '''
+ self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)
+ self.mv23d_predictor.model.eval()
+ self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
+ set_parameter_grad_false(self.mv23d_predictor.model)
+ print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
+
+ @torch.no_grad()
+ @timing_decorator("views to mesh")
+ @auto_amp_inference
+ def __call__(
+ self,
+ views_pil=None,
+ cond_pil=None,
+ gif_pil=None,
+ seed=0,
+ target_face_count = 10000,
+ do_texture_mapping = True,
+ save_folder='./outputs/test'
+ ):
+        '''
+        pass views_pil and cond_pil together, or pass gif_pil alone
+        seed: int
+        target_face_count: int
+        save_folder: path to save mesh files
+        '''
+ save_dir = save_folder
+ os.makedirs(save_dir, exist_ok=True)
+
+ if views_pil is not None and cond_pil is not None:
+ show_image = rearrange(np.asarray(views_pil, dtype=np.uint8),
+ '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+ views = [Image.fromarray(show_image[idx]) for idx in self.order]
+ image_list = [cond_pil]+ views
+ image_list = [img.convert('RGB') for img in image_list]
+ elif gif_pil is not None:
+ image_list = [img.convert('RGB') for img in ImageSequence.Iterator(gif_pil)]
+
+ image_input = image_list[0]
+ image_list = image_list[1:] + image_list[:1]
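+        # rotate the list so the condition image goes last and the generated views come first;
+        # the condition image is also passed separately as image_input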
+
+ seed_everything(seed)
+ self.mv23d_predictor.predict(
+ image_list,
+ save_dir = save_dir,
+ image_input = image_input,
+ target_face_count = target_face_count,
+ do_texture_mapping = do_texture_mapping
+ )
+ torch.cuda.empty_cache()
+ return save_dir
+
\ No newline at end of file
diff --git a/mvd/__init__.py b/mvd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mvd/hunyuan3d_mvd_lite_pipeline.py b/mvd/hunyuan3d_mvd_lite_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..29abd3a1792a8b96b50bbb61845cb0284255b113
--- /dev/null
+++ b/mvd/hunyuan3d_mvd_lite_pipeline.py
@@ -0,0 +1,493 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import math
+import numpy
+import torch
+import inspect
+import warnings
+from PIL import Image
+from einops import rearrange
+import torch.nn.functional as F
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from typing import Any, Callable, Dict, List, Optional, Union
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers import DDPMScheduler, EulerAncestralDiscreteScheduler, ImagePipelineOutput
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ LoraLoaderMixin,
+ TextualInversionLoaderMixin
+)
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection
+)
+from diffusers.models.attention_processor import (
+ Attention,
+ AttnProcessor,
+ XFormersAttnProcessor,
+ AttnProcessor2_0
+)
+
+from .utils import to_rgb_image, white_out_background, recenter_img
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from PIL import Image
+        >>> from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline
+
+        >>> pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
+        ...     "./weights/mvd_lite", torch_dtype=torch.float16
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> img = Image.open("demo.png")
+        >>> res_img = pipe(img).images[0]
+        ```
+"""
+
+def unscale_latents(latents): return latents / 0.75 + 0.22
+def unscale_image(image): return image / 0.50 * 0.80
+
+
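+# classifier-free guidance rescaling: match the std of the guided prediction to that of the
+# text-conditioned prediction, then blend by guidance_rescale (mirrors diffusers' rescale_noise_cfg helper)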
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+
+class ReferenceOnlyAttnProc(torch.nn.Module):
+ # reference attention
+ def __init__(self, chained_proc, enabled=False, name=None):
+ super().__init__()
+ self.enabled = enabled
+ self.chained_proc = chained_proc
+ self.name = name
+
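+    # mode "w": record this layer's hidden states into ref_dict during the reference pass;
+    # mode "r": concatenate the recorded states onto encoder_hidden_states for reference-only attention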
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
+ if encoder_hidden_states is None: encoder_hidden_states = hidden_states
+ if self.enabled:
+ if mode == 'w':
+ ref_dict[self.name] = encoder_hidden_states
+ elif mode == 'r':
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
+ return res
+
+
+# class RowWiseAttnProcessor2_0:
+# def __call__(self, attn,
+# hidden_states,
+# encoder_hidden_states=None,
+# attention_mask=None,
+# temb=None,
+# num_views=6,
+# *args,
+# **kwargs):
+# residual = hidden_states
+# if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb)
+
+# input_ndim = hidden_states.ndim
+# if input_ndim == 4:
+# batch_size, channel, height, width = hidden_states.shape
+# hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+# if encoder_hidden_states is None:
+# batch_size, sequence_length, _ = hidden_states.shape
+# else:
+# batch_size, sequence_length, _ = encoder_hidden_states.shape
+
+# if attention_mask is not None:
+# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+# if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+# query = attn.to_q(hidden_states)
+# if encoder_hidden_states is None: encoder_hidden_states = hidden_states
+# elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+# # encoder_hidden_states [B, 6hw+hw, C] if ref att
+# key = attn.to_k(encoder_hidden_states) # [B, Vhw+hw, C]
+# value = attn.to_v(encoder_hidden_states) # [B, Vhw+hw, C]
+
+# mv_flag = hidden_states.shape[1] < encoder_hidden_states.shape[1] and encoder_hidden_states.shape[1] != 77
+# if mv_flag:
+# target_size = int(math.sqrt(hidden_states.shape[1] // num_views))
+# assert target_size ** 2 * num_views == hidden_states.shape[1]
+
+# gen_key = key[:, :num_views*target_size*target_size, :]
+# ref_key = key[:, num_views*target_size*target_size:, :]
+# gen_value = value[:, :num_views*target_size*target_size, :]
+# ref_value = value[:, num_views*target_size*target_size:, :]
+
+# # rowwise attention
+# query, gen_key, gen_value = \
+# rearrange( query, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
+# v1=num_views//2, v2=2, h=target_size, w=target_size), \
+# rearrange( gen_key, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
+# v1=num_views//2, v2=2, h=target_size, w=target_size), \
+# rearrange(gen_value, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
+# v1=num_views//2, v2=2, h=target_size, w=target_size)
+
+# inner_dim = key.shape[-1]
+# ref_size = int(math.sqrt(ref_key.shape[1]))
+# ref_key_expanded = ref_key.view(batch_size, 1, ref_size * ref_size, inner_dim)
+# ref_key_expanded = ref_key_expanded.expand(-1, target_size, -1, -1).contiguous()
+# ref_key_expanded = ref_key_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
+# key = torch.cat([ gen_key, ref_key_expanded], dim=1)
+
+# ref_value_expanded = ref_value.view(batch_size, 1, ref_size * ref_size, inner_dim)
+# ref_value_expanded = ref_value_expanded.expand(-1, target_size, -1, -1).contiguous()
+# ref_value_expanded = ref_value_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
+# value = torch.cat([gen_value, ref_value_expanded], dim=1)
+# h = target_size
+# else:
+# target_size = int(math.sqrt(hidden_states.shape[1]))
+# h = 1
+# num_views = 1
+
+# inner_dim = key.shape[-1]
+# head_dim = inner_dim // attn.heads
+
+# query = query.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
+# key = key.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
+# value = value.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
+
+# hidden_states = F.scaled_dot_product_attention(query, key, value,
+# attn_mask=attention_mask,
+# dropout_p=0.0,
+# is_causal=False)
+# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * h,
+# -1,
+# attn.heads * head_dim).to(query.dtype)
+# hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))
+
+# if mv_flag: hidden_states = rearrange(hidden_states, "(b h) (v1 v2 w) c -> b (v1 h v2 w) c",
+# b=batch_size, v1=num_views//2,
+# v2=2, h=target_size, w=target_size)
+
+# if input_ndim == 4:
+# hidden_states = hidden_states.transpose(-1, -2)
+# hidden_states = hidden_states.reshape(batch_size,
+# channel,
+# target_size,
+# target_size)
+# if attn.residual_connection: hidden_states = hidden_states + residual
+# hidden_states = hidden_states / attn.rescale_output_factor
+# return hidden_states
+
+
+class RefOnlyNoisedUNet(torch.nn.Module):
+ def __init__(self, unet, train_sched, val_sched):
+ super().__init__()
+ self.unet = unet
+ self.train_sched = train_sched
+ self.val_sched = val_sched
+
+ unet_lora_attn_procs = dict()
+ for name, _ in unet.attn_processors.items():
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(AttnProcessor2_0(),
+ enabled=name.endswith("attn1.processor"),
+ name=name)
+ unet.set_attn_processor(unet_lora_attn_procs)
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.unet, name)
+
+ def forward(self, sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs, **kwargs):
+ cond_lat = cross_attention_kwargs['cond_lat']
+ noise = torch.randn_like(cond_lat)
+ if self.training:
+ noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
+ noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
+ else:
+ noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
+ noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
+
+ ref_dict = {}
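+        # write pass ("w"): run the UNet on the noised condition latent to record per-layer reference
+        # features into ref_dict, then reuse them in the read pass ("r") when denoising the sample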
+ self.unet(noisy_cond_lat,
+ timestep,
+ encoder_hidden_states,
+ *args,
+ cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
+ **kwargs)
+ return self.unet(sample,
+ timestep,
+ encoder_hidden_states,
+ *args,
+ cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict),
+ **kwargs)
+
+
+class Hunyuan3d_MVD_Lite_Pipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ vision_encoder: CLIPVisionModelWithProjection,
+ feature_extractor_clip: CLIPImageProcessor,
+ feature_extractor_vae: CLIPImageProcessor,
+ ramping_coefficients: Optional[list] = None,
+ safety_checker=None,
+ ):
+ DiffusionPipeline.__init__(self)
+ self.register_modules(
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ vision_encoder=vision_encoder,
+ feature_extractor_vae=feature_extractor_vae,
+ feature_extractor_clip=feature_extractor_clip)
+        # module layout mirrors the standard Stable Diffusion pipeline, with an added CLIP vision
+        # encoder and separate feature extractors for the VAE and CLIP condition inputs
+ self.register_to_config(ramping_coefficients=ramping_coefficients)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ extra_step_kwargs = {}
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_eta: extra_step_kwargs["eta"] = eta
+
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator: extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None: uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError()
+ elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt): raise ValueError()
+ else: uncond_tokens = negative_prompt
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt")
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=attention_mask)
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ @torch.no_grad()
+ def encode_condition_image(self, image: torch.Tensor): return self.vae.encode(image).latent_dist.sample()
+
+ @torch.no_grad()
+ def __call__(self, image=None,
+ width=640,
+ height=960,
+ num_inference_steps=75,
+ return_dict=True,
+ generator=None,
+ **kwargs):
+ batch_size = 1
+ num_images_per_prompt = 1
+ output_type = 'pil'
+ do_classifier_free_guidance = True
+ guidance_rescale = 0.
+ if isinstance(self.unet, UNet2DConditionModel):
+ self.unet = RefOnlyNoisedUNet(self.unet, None, self.scheduler).eval()
+
+        cond_image = recenter_img(image)
+        cond_image = to_rgb_image(cond_image)
+ image = cond_image
+ image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
+ image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
+ image_1 = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
+ image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
+
+ cond_lat = self.encode_condition_image(image_1)
+ negative_lat = self.encode_condition_image(torch.zeros_like(image_1))
+ cond_lat = torch.cat([negative_lat, cond_lat])
+ cross_attention_kwargs = dict(cond_lat=cond_lat)
+
+ global_embeds = self.vision_encoder(image_2, output_hidden_states=False).image_embeds.unsqueeze(-2)
+ encoder_hidden_states = self._encode_prompt('', self.device, num_images_per_prompt, False)
+ ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
+ prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states + global_embeds * ramp])
+
+ device = self._execution_device
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ None)
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        # set adaptive cfg
+        # the image order in the 3x2 grid is:
+        # [0,   60,
+        #  120, 180,
+        #  240, 300]
+        # and each 40x40 latent tile gets its own cfg multiplier: 3, 2.5, 2, 1.5, 2, 2.5 (row-major),
+        # which is further scaled by the timestep-dependent factor below
+
+ tmp_guidance_scale = torch.ones_like(latents)
+ tmp_guidance_scale[:, :, :40, :40] = 3
+ tmp_guidance_scale[:, :, :40, 40:] = 2.5
+ tmp_guidance_scale[:, :, 40:80, :40] = 2
+ tmp_guidance_scale[:, :, 40:80, 40:] = 1.5
+ tmp_guidance_scale[:, :, 80:120, :40] = 2
+ tmp_guidance_scale[:, :, 80:120, 40:] = 2.5
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ noise_pred = self.unet(latent_model_input, t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False)[0]
+
+ adaptive_guidance_scale = (2 + 16 * (t / 1000) ** 5) / 3
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + \
+ tmp_guidance_scale * adaptive_guidance_scale * \
+ (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if i==len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order==0):
+ progress_bar.update()
+
+ latents = unscale_latents(latents)
+ image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
+ image = self.image_processor.postprocess(image, output_type='pil')[0]
+ image = [image, cond_image]
+ return ImagePipelineOutput(images=image) if return_dict else (image,)
+
diff --git a/mvd/hunyuan3d_mvd_std_pipeline.py b/mvd/hunyuan3d_mvd_std_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..88cd8bae1a35d663e3275e50077af59ee3fcae33
--- /dev/null
+++ b/mvd/hunyuan3d_mvd_std_pipeline.py
@@ -0,0 +1,471 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import os
+import torch
+import numpy as np
+from PIL import Image
+
+import diffusers
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.models.attention_processor import (
+ Attention,
+ AttnProcessor,
+ XFormersAttnProcessor,
+ AttnProcessor2_0
+)
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ EulerAncestralDiscreteScheduler,
+ UNet2DConditionModel,
+ ImagePipelineOutput
+)
+import transformers
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ CLIPTextModelWithProjection
+)
+
+from .utils import to_rgb_image, white_out_background, recenter_img
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+        >>> from PIL import Image
+        >>> from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
+
+        >>> pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
+ ... "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> img = Image.open("demo.png")
+ >>> res_img = pipe(img).images[0]
+ ```
+"""
+
+
+
+def scale_latents(latents): return (latents - 0.22) * 0.75
+def unscale_latents(latents): return (latents / 0.75) + 0.22
+def scale_image(image): return (image - 0.5) / 0.5
+def scale_image_2(image): return (image * 0.5) / 0.8
+def unscale_image(image): return (image * 0.5) + 0.5
+def unscale_image_2(image): return (image * 0.8) / 0.5
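+# Note: each scale_*/unscale_* pair above is an exact inverse (e.g. unscale_latents(scale_latents(x)) == x);
+# the pipeline below uses the unscale_* variants to map the denoised latents and the decoded image
+# back to their native ranges before decoding/postprocessing.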
+
+
+
+
+class ReferenceOnlyAttnProc(torch.nn.Module):
+ def __init__(self, chained_proc, enabled=False, name=None):
+ super().__init__()
+ self.enabled = enabled
+ self.chained_proc = chained_proc
+ self.name = name
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
+ encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+ if self.enabled:
+ if mode == 'w': ref_dict[self.name] = encoder_hidden_states
+ elif mode == 'r': encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
+ else: raise Exception(f"mode should not be {mode}")
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
+
+
+class RefOnlyNoisedUNet(torch.nn.Module):
+ def __init__(self, unet, scheduler) -> None:
+ super().__init__()
+ self.unet = unet
+ self.scheduler = scheduler
+
+ unet_attn_procs = dict()
+ for name, _ in unet.attn_processors.items():
+            # prefer the PyTorch 2.0 SDPA processor, then xformers, then the vanilla processor
+            if hasattr(torch.nn.functional, "scaled_dot_product_attention"): default_attn_proc = AttnProcessor2_0()
+            elif is_xformers_available(): default_attn_proc = XFormersAttnProcessor()
+            else: default_attn_proc = AttnProcessor()
+ unet_attn_procs[name] = ReferenceOnlyAttnProc(
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
+ )
+ unet.set_attn_processor(unet_attn_procs)
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.unet, name)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ class_labels: Optional[torch.Tensor] = None,
+ down_block_res_samples: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ return_dict: bool = True,
+ **kwargs
+ ):
+
+ dtype = self.unet.dtype
+
+        # add noise to cond_lat at the current timestep so the reference pass sees the same noise level
+ cond_lat = cross_attention_kwargs['cond_lat']
+ noise = torch.randn_like(cond_lat)
+
+ noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1))
+ noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
+
+ ref_dict = {}
+
+ _ = self.unet(
+ noisy_cond_lat,
+ timestep,
+ encoder_hidden_states = encoder_hidden_states,
+ class_labels = class_labels,
+ cross_attention_kwargs = dict(mode="w", ref_dict=ref_dict),
+ added_cond_kwargs = added_cond_kwargs,
+ return_dict = return_dict,
+ **kwargs
+ )
+
+ res = self.unet(
+ sample,
+ timestep,
+ encoder_hidden_states,
+ class_labels=class_labels,
+ cross_attention_kwargs = dict(mode="r", ref_dict=ref_dict),
+ down_block_additional_residuals = [
+ sample.to(dtype=dtype) for sample in down_block_res_samples
+ ] if down_block_res_samples is not None else None,
+ mid_block_additional_residual = (
+ mid_block_res_sample.to(dtype=dtype)
+ if mid_block_res_sample is not None else None),
+ added_cond_kwargs = added_cond_kwargs,
+ return_dict = return_dict,
+ **kwargs
+ )
+ return res
+
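+# A short sketch of the reference-only flow implemented above, using the names defined in this
+# file (the tensors and the "..." arguments are illustrative, not extra API):
+#
+#   ref_dict = {}
+#   # pass 1, mode="w": run the UNet on the noised condition latent; every self-attention
+#   # processor stores its hidden states in ref_dict under its layer name.
+#   _ = unet(noisy_cond_lat, t, ..., cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict))
+#   # pass 2, mode="r": run the UNet on the multiview sample; each self-attention layer pops its
+#   # stored reference states and concatenates them to its own keys/values before attending.
+#   out = unet(sample, t, ..., cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict))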
+
+
+class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ feature_extractor_vae: CLIPImageProcessor,
+ vision_processor: CLIPImageProcessor,
+ vision_encoder: CLIPVisionModelWithProjection,
+ vision_encoder_2: CLIPVisionModelWithProjection,
+ ramping_coefficients: Optional[list] = None,
+ add_watermarker: Optional[bool] = None,
+ safety_checker = None,
+ ):
+ DiffusionPipeline.__init__(self)
+
+ self.register_modules(
+ vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae,
+ vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2,
+ )
+ self.register_to_config( ramping_coefficients = ramping_coefficients)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+ self.watermark = None
+ self.prepare_init = False
+
+ def prepare(self):
+ assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel"
+ self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval()
+ self.prepare_init = True
+
+ def encode_image(self, image: torch.Tensor, scale_factor: bool = False):
+ latent = self.vae.encode(image).latent_dist.sample()
+ return (latent * self.vae.config.scaling_factor) if scale_factor else latent
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
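+    # SDXL-style micro-conditioning: the added time ids concatenate original_size,
+    # crops_coords_top_left and target_size; they are embedded by unet.add_embedding, and the
+    # length check below guards against a mismatched UNet / text-encoder projection config.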
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
+ f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config." \
+ f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta: extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator: extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Image.Image = None,
+ guidance_scale = 2.0,
+ output_type: Optional[str] = "pil",
+ num_inference_steps: int = 50,
+ return_dict: bool = True,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ latent: torch.Tensor = None,
+ guidance_curve = None,
+ **kwargs
+ ):
+ if not self.prepare_init:
+ self.prepare()
+
+ here = dict(device=self.vae.device, dtype=self.vae.dtype)
+
+ batch_size = 1
+ num_images_per_prompt = 1
+ width, height = 512 * 2, 512 * 3
+ target_size = original_size = (height, width)
+
+ self._guidance_scale = guidance_scale
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ self.vae.dtype,
+ device,
+ generator,
+ latents=latent,
+ )
+
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+
+ # Prepare added time ids & embeddings
+ text_encoder_projection_dim = 1280
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=self.vae.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ negative_add_time_ids = add_time_ids
+
+ # hw: preprocess
+        cond_image = recenter_img(image)
+        cond_image = to_rgb_image(cond_image)
+ image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here)
+ image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here)
+
+ # hw: get cond_lat from cond_img using vae
+ cond_lat = self.encode_image(image_vae, scale_factor=False)
+ negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False)
+ cond_lat = torch.cat([negative_lat, cond_lat])
+
+ # hw: get visual global embedding using clip
+ global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
+ global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
+ global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1)
+
+ ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
+ prompt_embeds = self.uc_text_emb.to(**here)
+ pooled_prompt_embeds = self.uc_text_emb_2.to(**here)
+
+ prompt_embeds = prompt_embeds + global_embeds * ramp
+ add_text_embeds = pooled_prompt_embeds
+
+ if self.do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ timestep_cond = None
+ self._num_timesteps = len(timesteps)
+
+ if guidance_curve is None:
+ guidance_curve = lambda t: guidance_scale
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=dict(cond_lat=cond_lat),
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+
+ # cur_guidance_scale = self.guidance_scale
+ cur_guidance_scale = guidance_curve(t) # 1.5 + 2.5 * ((t/1000)**2)
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # cur_guidance_scale_topleft = (cur_guidance_scale - 1.0) * 4 + 1.0
+ # noise_pred_top_left = noise_pred_uncond +
+ # cur_guidance_scale_topleft * (noise_pred_text - noise_pred_uncond)
+ # _, _, h, w = noise_pred.shape
+ # noise_pred[:, :, :h//3, :w//2] = noise_pred_top_left[:, :, :h//3, :w//2]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ latents = unscale_latents(latents)
+
+ if output_type=="latent":
+ image = latents
+ else:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = unscale_image(unscale_image_2(image)).clamp(0, 1)
+ image = [
+ Image.fromarray((image[0]*255+0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")),
+ # self.image_processor.postprocess(image, output_type=output_type)[0],
+ cond_image.resize((512, 512))
+ ]
+
+ if not return_dict: return (image,)
+ return ImagePipelineOutput(images=image)
+
+ def save_pretrained(self, save_directory):
+ # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
+ super().save_pretrained(save_directory)
+ torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt"))
+ torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt"))
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
+ pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt"), map_location="cpu")
+        pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt"), map_location="cpu")
+ return pipeline
diff --git a/mvd/utils.py b/mvd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7423e9d75446363d15b0eb9e547e17f31c65f88
--- /dev/null
+++ b/mvd/utils.py
@@ -0,0 +1,85 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import numpy as np
+from PIL import Image
+
+def to_rgb_image(maybe_rgba: Image.Image):
+ '''
+ convert a PIL.Image to rgb mode with white background
+ maybe_rgba: PIL.Image
+ return: PIL.Image
+ '''
+ if maybe_rgba.mode == 'RGB':
+ return maybe_rgba
+ elif maybe_rgba.mode == 'RGBA':
+ rgba = maybe_rgba
+        img = np.full([rgba.size[1], rgba.size[0], 3], 255, dtype=np.uint8)  # white background
+        img = Image.fromarray(img, 'RGB')
+ img.paste(rgba, mask=rgba.getchannel('A'))
+ return img
+ else:
+ raise ValueError("Unsupported image type.", maybe_rgba.mode)
+
+def white_out_background(pil_img, is_gray_fg=True):
+ data = pil_img.getdata()
+ new_data = []
+    # make near-white foreground pixels light gray so they remain visible on the white background
+    for r, g, b, a in data:
+        if a < 16:
+            new_data.append((255, 255, 255, 0))  # background: fully transparent white
+ else:
+ is_white = is_gray_fg and (r>235) and (g>235) and (b>235)
+ new_r = 235 if is_white else r
+ new_g = 235 if is_white else g
+ new_b = 235 if is_white else b
+ new_data.append((new_r, new_g, new_b, a))
+ pil_img.putdata(new_data)
+ return pil_img
+
+def recenter_img(img, size=512, color=(255,255,255)):
+ img = white_out_background(img)
+ mask = np.array(img)[..., 3]
+ image = np.array(img)[..., :3]
+
+ H, W, C = image.shape
+ coords = np.nonzero(mask)
+ x_min, x_max = coords[0].min(), coords[0].max()
+ y_min, y_max = coords[1].min(), coords[1].max()
+ h = x_max - x_min
+ w = y_max - y_min
+    if h == 0 or w == 0: raise ValueError("input image has an empty foreground mask")
+ roi = image[x_min:x_max, y_min:y_max]
+
+ border_ratio = 0.15 # 0.2
+ pad_h = int(h * border_ratio)
+ pad_w = int(w * border_ratio)
+
+ result_tmp = np.full((h + pad_h, w + pad_w, C), color, dtype=np.uint8)
+ result_tmp[pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w] = roi
+
+ cur_h, cur_w = result_tmp.shape[:2]
+ side = max(cur_h, cur_w)
+ result = np.full((side, side, C), color, dtype=np.uint8)
+ result[(side-cur_h)//2:(side-cur_h)//2+cur_h, (side-cur_w)//2:(side - cur_w)//2+cur_w,:] = result_tmp
+ result = Image.fromarray(result)
+ return result.resize((size, size), Image.LANCZOS) if size else result
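+# Minimal usage sketch for the helpers above (the input is assumed to be an RGBA image, e.g.
+# one produced by rembg; the file path is only illustrative):
+#
+#   from PIL import Image
+#   from mvd.utils import to_rgb_image, recenter_img
+#
+#   rgba = Image.open("./demos/example_000.png").convert("RGBA")
+#   centered = recenter_img(rgba, size=512)   # white background, foreground recentered, 512x512
+#   rgb = to_rgb_image(centered)              # flattens alpha onto white; no-op for RGB inputs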
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a187439832c19e4a7902be324919d035c28a097c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+--find-links https://download.pytorch.org/whl/cu118
+torch==2.2.0
+torchvision==0.17.0
+diffusers
+transformers
+rembg
+tqdm
+omegaconf
+matplotlib
+opencv-python
+imageio
+jaxtyping
+einops
+SentencePiece
+accelerate
+trimesh
+PyMCubes
+xatlas
+libigl
+git+https://github.com/facebookresearch/pytorch3d
+git+https://github.com/NVlabs/nvdiffrast
+open3d
\ No newline at end of file
diff --git a/scripts/image_to_3d.sh b/scripts/image_to_3d.sh
new file mode 100644
index 0000000000000000000000000000000000000000..730bef20d97efe4dc715c4164d46007b5072bfe8
--- /dev/null
+++ b/scripts/image_to_3d.sh
@@ -0,0 +1,8 @@
+# image to 3d
+
+python main.py \
+ --image_prompt ./demos/example_000.png \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 90000 \
+ --do_texture \
+ --do_render
diff --git a/scripts/image_to_3d_demo.sh b/scripts/image_to_3d_demo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eea6263c4e3e7a4ca1685bd26c0299c05f8e2e83
--- /dev/null
+++ b/scripts/image_to_3d_demo.sh
@@ -0,0 +1,8 @@
+# image to 3d
+
+python main.py \
+ --image_prompt ./demos/example_000.png \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 90000 \
+ --do_texture_mapping \
+ --do_render
diff --git a/scripts/image_to_3d_fast.sh b/scripts/image_to_3d_fast.sh
new file mode 100644
index 0000000000000000000000000000000000000000..024107a438e01dbad99924bf59cad92fb924e307
--- /dev/null
+++ b/scripts/image_to_3d_fast.sh
@@ -0,0 +1,6 @@
+# image to 3d fast
+python main.py \
+ --image_prompt ./demos/example_000.png \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 10000 \
+ --use_lite
diff --git a/scripts/image_to_3d_fast_demo.sh b/scripts/image_to_3d_fast_demo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..024107a438e01dbad99924bf59cad92fb924e307
--- /dev/null
+++ b/scripts/image_to_3d_fast_demo.sh
@@ -0,0 +1,6 @@
+# image to 3d fast
+python main.py \
+ --image_prompt ./demos/example_000.png \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 10000 \
+ --use_lite
diff --git a/scripts/text_to_3d.sh b/scripts/text_to_3d.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a81341df18eceef0f47ad8e41229fbc6ff0de3b6
--- /dev/null
+++ b/scripts/text_to_3d.sh
@@ -0,0 +1,7 @@
+# text to 3d
+python main.py \
+ --text_prompt "a lovely cat" \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 90000 \
+ --do_texture \
+ --do_render
diff --git a/scripts/text_to_3d_demo.sh b/scripts/text_to_3d_demo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..88f46bfc1d3da38b08014a4dda85103dc6e19b9d
--- /dev/null
+++ b/scripts/text_to_3d_demo.sh
@@ -0,0 +1,7 @@
+# text to 3d
+python main.py \
+ --text_prompt "a lovely rabbit" \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 90000 \
+ --do_texture_mapping \
+ --do_render
diff --git a/scripts/text_to_3d_fast.sh b/scripts/text_to_3d_fast.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ed41c4a887d050dbd6b2da839c87d174b5c6efc1
--- /dev/null
+++ b/scripts/text_to_3d_fast.sh
@@ -0,0 +1,6 @@
+# text to 3d fast (prompt: "a Cantonese-style teacup")
+python main.py \
+ --text_prompt "一个广式茶杯" \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 10000 \
+ --use_lite
diff --git a/scripts/text_to_3d_fast_demo.sh b/scripts/text_to_3d_fast_demo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ed41c4a887d050dbd6b2da839c87d174b5c6efc1
--- /dev/null
+++ b/scripts/text_to_3d_fast_demo.sh
@@ -0,0 +1,6 @@
+# text to 3d fast (prompt: "a Cantonese-style teacup")
+python main.py \
+ --text_prompt "一个广式茶杯" \
+ --save_folder ./outputs/test/ \
+ --max_faces_num 10000 \
+ --use_lite
diff --git a/svrm/.DS_Store b/svrm/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..d9f5f94a8d4408fa092739ff9d5775a56e8da914
Binary files /dev/null and b/svrm/.DS_Store differ
diff --git a/svrm/configs/2024-10-24T22-36-18-project.yaml b/svrm/configs/2024-10-24T22-36-18-project.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..286ff8e108a3b241aed860d21681479ccf0a1e63
--- /dev/null
+++ b/svrm/configs/2024-10-24T22-36-18-project.yaml
@@ -0,0 +1,32 @@
+model:
+ base_learning_rate: 3.0e-05
+ target: svrm.ldm.models.svrm.SVRMModel
+ params:
+
+ img_encoder_config:
+ target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
+ params:
+ version: dinov2_vitb14
+
+ img_to_triplane_config:
+ target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
+ params:
+ pos_emb_size: 64
+ pos_emb_dim: 1024
+ cam_cond_dim: 20
+ n_heads: 16
+ d_head: 64
+ depth: 16
+ context_dim: 768
+ triplane_dim: 120
+ use_fp16: true
+ use_bf16: false
+ upsample_time: 2
+
+ render_config:
+ target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
+ params:
+ triplane_dim: 120
+ samples_per_ray: 128
+
+
diff --git a/svrm/configs/svrm.yaml b/svrm/configs/svrm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..286ff8e108a3b241aed860d21681479ccf0a1e63
--- /dev/null
+++ b/svrm/configs/svrm.yaml
@@ -0,0 +1,32 @@
+model:
+ base_learning_rate: 3.0e-05
+ target: svrm.ldm.models.svrm.SVRMModel
+ params:
+
+ img_encoder_config:
+ target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
+ params:
+ version: dinov2_vitb14
+
+ img_to_triplane_config:
+ target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
+ params:
+ pos_emb_size: 64
+ pos_emb_dim: 1024
+ cam_cond_dim: 20
+ n_heads: 16
+ d_head: 64
+ depth: 16
+ context_dim: 768
+ triplane_dim: 120
+ use_fp16: true
+ use_bf16: false
+ upsample_time: 2
+
+ render_config:
+ target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
+ params:
+ triplane_dim: 120
+ samples_per_ray: 128
+
+
diff --git a/svrm/ldm/.DS_Store b/svrm/ldm/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..fb57e5a2bfa5f6208a6a12d082e457410cb415be
Binary files /dev/null and b/svrm/ldm/.DS_Store differ
diff --git a/svrm/ldm/models/svrm.py b/svrm/ldm/models/svrm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12511ec16095176b1d3bed70eaa18a29b7e1b0f
--- /dev/null
+++ b/svrm/ldm/models/svrm.py
@@ -0,0 +1,263 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import os
+import time
+import math
+import cv2
+import numpy as np
+import itertools
+import shutil
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+try:
+ import trimesh
+ import mcubes
+ import xatlas
+ import open3d as o3d
+except ImportError as e:
+    raise ImportError("failed to import 3D libraries (trimesh, PyMCubes, xatlas, open3d)") from e
+
+from ..modules.rendering_neus.mesh import Mesh
+from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext
+
+from ..utils.ops import scale_tensor
+from ..util import count_params, instantiate_from_config
+from ..vis_util import render
+
+
+def unwrap_uv(v_pos, t_pos_idx):
+ print("Using xatlas to perform UV unwrapping, may take a while ...")
+ atlas = xatlas.Atlas()
+ atlas.add_mesh(v_pos, t_pos_idx)
+ atlas.generate(xatlas.ChartOptions(), xatlas.PackOptions())
+ _, indices, uvs = atlas.get_mesh(0)
+ indices = indices.astype(np.int64, casting="same_kind")
+ return uvs, indices
+
+
+def uv_padding(image, hole_mask, uv_padding_size = 2):
+ return cv2.inpaint(
+ (image.detach().cpu().numpy() * 255).astype(np.uint8),
+ (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
+ uv_padding_size,
+ cv2.INPAINT_TELEA
+ )
+
+def refine_mesh(vtx_refine, faces_refine):
+ mesh = o3d.geometry.TriangleMesh(
+ vertices=o3d.utility.Vector3dVector(vtx_refine),
+ triangles=o3d.utility.Vector3iVector(faces_refine))
+
+ mesh = mesh.remove_unreferenced_vertices()
+ mesh = mesh.remove_duplicated_triangles()
+ mesh = mesh.remove_duplicated_vertices()
+
+ voxel_size = max(mesh.get_max_bound() - mesh.get_min_bound())
+
+ mesh = mesh.simplify_vertex_clustering(
+ voxel_size=0.007, # 0.005
+ contraction=o3d.geometry.SimplificationContraction.Average)
+
+ mesh = mesh.filter_smooth_simple(number_of_iterations=2)
+
+ vtx_refine = np.asarray(mesh.vertices).astype(np.float32)
+ faces_refine = np.asarray(mesh.triangles)
+ return vtx_refine, faces_refine, mesh
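+# refine_mesh above cleans the marching-cubes output: it drops unreferenced/duplicated geometry,
+# merges nearby vertices with voxel clustering (voxel size 0.007) and applies two iterations of
+# simple neighbourhood smoothing; quadric decimation happens later in export_mesh_with_uv.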
+
+
+class SVRMModel(torch.nn.Module):
+ def __init__(
+ self,
+ img_encoder_config,
+ img_to_triplane_config,
+ render_config,
+ device = "cuda:0",
+ **kwargs
+ ):
+ super().__init__()
+
+ self.img_encoder = instantiate_from_config(img_encoder_config).half()
+ self.img_to_triplane_decoder = instantiate_from_config(img_to_triplane_config).half()
+ self.render = instantiate_from_config(render_config).half()
+ self.device = device
+ count_params(self, verbose=True)
+
+ @torch.no_grad()
+ def export_mesh_with_uv(
+ self,
+ data,
+ mesh_size: int = 384,
+ ctx = None,
+ context_type = 'cuda',
+ texture_res = 1024,
+ target_face_count = 10000,
+ do_texture_mapping = True,
+ out_dir = 'outputs/test'
+ ):
+        """
+        Export a mesh from the input multi-view data: if do_texture_mapping is False only a
+        vertex-colored mesh is written, otherwise the mesh is UV-unwrapped and a texture map is baked.
+        """
+ st = time.time()
+ here = {'device': self.device, 'dtype': torch.float16}
+ input_view_image = data["input_view"].to(**here) # [b, m, c, h, w]
+ input_view_cam = data["input_view_cam"].to(**here) # [b, m, 20]
+
+ batch_size, input_view_num, *_ = input_view_image.shape
+ assert batch_size == 1, "batch size should be 1"
+
+ input_view_image = rearrange(input_view_image, 'b m c h w -> (b m) c h w')
+ input_view_cam = rearrange(input_view_cam, 'b m d -> (b m) d')
+ input_view_feat = self.img_encoder(input_view_image, input_view_cam)
+ input_view_feat = rearrange(input_view_feat, '(b m) l d -> b (l m) d', m=input_view_num)
+
+ # -- decoder
+ torch.cuda.empty_cache()
+ triplane_gen = self.img_to_triplane_decoder(input_view_feat) # [b, 3, tri_dim, h, w]
+ del input_view_feat
+ torch.cuda.empty_cache()
+
+ # --- triplane nerf render
+
+ cur_triplane = triplane_gen[0:1]
+
+ aabb = torch.tensor([[-0.6, -0.6, -0.6], [0.6, 0.6, 0.6]]).unsqueeze(0).to(**here)
+ grid_out = self.render.forward_grid(planes=cur_triplane, grid_size=mesh_size, aabb=aabb)
+
+ print(f"=====> LRM forward time: {time.time() - st}")
+ st = time.time()
+
+ vtx, faces = mcubes.marching_cubes(0. - grid_out['sdf'].squeeze(0).squeeze(-1).cpu().float().numpy(), 0)
+
+ bbox = aabb[0].cpu().numpy()
+ vtx = vtx / (mesh_size - 1)
+ vtx = vtx * (bbox[1] - bbox[0]) + bbox[0]
+
+ # refine mesh
+ vtx_refine, faces_refine, mesh = refine_mesh(vtx, faces)
+
+ # reduce faces
+ if faces_refine.shape[0] > target_face_count:
+ print(f"reduce face: {faces_refine.shape[0]} -> {target_face_count}")
+ mesh = o3d.geometry.TriangleMesh(
+ vertices = o3d.utility.Vector3dVector(vtx_refine),
+ triangles = o3d.utility.Vector3iVector(faces_refine)
+ )
+
+ # Function to simplify mesh using Quadric Error Metric Decimation by Garland and Heckbert
+ mesh = mesh.simplify_quadric_decimation(target_face_count, boundary_weight=1.0)
+
+ mesh = Mesh(
+ v_pos = torch.from_numpy(np.asarray(mesh.vertices)).to(self.device),
+ t_pos_idx = torch.from_numpy(np.asarray(mesh.triangles)).to(self.device),
+ v_rgb = torch.from_numpy(np.asarray(mesh.vertex_colors)).to(self.device)
+ )
+ vtx_refine = mesh.v_pos.cpu().numpy()
+ faces_refine = mesh.t_pos_idx.cpu().numpy()
+
+ vtx_colors = self.render.forward_points(cur_triplane, torch.tensor(vtx_refine).unsqueeze(0).to(**here))
+ vtx_colors = vtx_colors['rgb'].float().squeeze(0).cpu().numpy()
+
+ color_ratio = 0.8 # increase brightness
+ with open(f'{out_dir}/mesh_with_colors.obj', 'w') as fid:
+ verts = vtx_refine[:, [1,2,0]]
+ for pidx, pp in enumerate(verts):
+ color = vtx_colors[pidx]
+ color = [color[0]**color_ratio, color[1]**color_ratio, color[2]**color_ratio]
+ fid.write('v %f %f %f %f %f %f\n' % (pp[0], pp[1], pp[2], color[0], color[1], color[2]))
+ for i, f in enumerate(faces_refine):
+ f1 = f + 1
+ fid.write('f %d %d %d\n' % (f1[0], f1[1], f1[2]))
+
+ mesh = trimesh.load_mesh(f'{out_dir}/mesh_with_colors.obj')
+ print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
+ st = time.time()
+
+ if not do_texture_mapping:
+ shutil.copy(f'{out_dir}/mesh_with_colors.obj', f'{out_dir}/mesh.obj')
+ mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
+ return None
+
+ ########## export texture ########
+ st = time.time()
+
+ # uv unwrap
+ vtx_tex, t_tex_idx = unwrap_uv(vtx_refine, faces_refine)
+ vtx_refine = torch.from_numpy(vtx_refine).to(self.device)
+ faces_refine = torch.from_numpy(faces_refine).to(self.device)
+ t_tex_idx = torch.from_numpy(t_tex_idx).to(self.device)
+ uv_clip = torch.from_numpy(vtx_tex * 2.0 - 1.0).to(self.device)
+
+ # rasterize
+ ctx = NVDiffRasterizerContext(context_type, cur_triplane.device) if ctx is None else ctx
+ rast = ctx.rasterize_one(
+ torch.cat([
+ uv_clip,
+ torch.zeros_like(uv_clip[..., 0:1]),
+ torch.ones_like(uv_clip[..., 0:1])
+ ], dim=-1),
+ t_tex_idx,
+ (texture_res, texture_res)
+ )[0]
+ hole_mask = ~(rast[:, :, 3] > 0)
+
+ # Interpolate world space position
+ gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
+ with torch.no_grad():
+ gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
+ tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
+ tex_map = tex_map.float().squeeze(0) # (0, 1)
+ tex_map = tex_map.view((texture_res, texture_res, 3))
+ img = uv_padding(tex_map, hole_mask)
+ img = ((img/255.0) ** color_ratio) * 255 # increase brightness
+ img = img.clip(0, 255).astype(np.uint8)
+
+ verts = vtx_refine.cpu().numpy()[:, [1,2,0]]
+ faces = faces_refine.cpu().numpy()
+
+ with open(f'{out_dir}/texture.mtl', 'w') as fid:
+ fid.write('newmtl material_0\n')
+ fid.write("Ka 1.000 1.000 1.000\n")
+ fid.write("Kd 1.000 1.000 1.000\n")
+ fid.write("Ks 0.000 0.000 0.000\n")
+ fid.write("d 1.0\n")
+ fid.write("illum 2\n")
+ fid.write(f'map_Kd texture.png\n')
+
+ with open(f'{out_dir}/mesh.obj', 'w') as fid:
+ fid.write(f'mtllib texture.mtl\n')
+ for pidx, pp in enumerate(verts):
+ fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))
+ for pidx, pp in enumerate(vtx_tex):
+ fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))
+ fid.write('usemtl material_0\n')
+ for i, f in enumerate(faces):
+ f1 = f + 1
+ f2 = t_tex_idx[i] + 1
+ fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2],))
+
+ cv2.imwrite(f'{out_dir}/texture.png', img[..., [2, 1, 0]])
+ mesh = trimesh.load_mesh(f'{out_dir}/mesh.obj')
+ mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
+
\ No newline at end of file
diff --git a/svrm/ldm/modules/attention.py b/svrm/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..380d3d7b8f1e87aee101c4ac6a9797e3de0682c9
--- /dev/null
+++ b/svrm/ldm/modules/attention.py
@@ -0,0 +1,457 @@
+from inspect import isfunction
+from typing import Any, Optional
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+import numpy as np
+
+FLASH_IS_AVAILABLE = XFORMERS_IS_AVAILBLE = False
+try:
+    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
+    FLASH_IS_AVAILABLE = True
+except ImportError:
+    try:
+        import xformers
+        import xformers.ops
+        XFORMERS_IS_AVAILBLE = True
+    except ImportError:
+        pass
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
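+# LinearAttention avoids materialising the full (hw x hw) attention matrix: the softmax is taken
+# over key positions, keys and values are contracted into a small per-head context, and the
+# queries are then projected through that context, so cost grows linearly in the number of positions.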
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+ out = einsum('b i j, b j d -> b i d', attn, v) # [b*h, n, d]
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class FlashAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.dropout = dropout
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
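+    # flash-attn kernels expect fp16/bf16 inputs, so q/k/v are cast to bfloat16 for the kernel
+    # and the result is cast back to float32 before the output projection.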
+ def forward(self, x, context=None, mask=None):
+ context = default(context, x)
+ h = self.heads
+ dtype = torch.bfloat16 # torch.half
+ q = self.to_q(x).to(dtype)
+ k = self.to_k(context).to(dtype)
+ v = self.to_v(context).to(dtype)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
+ out = flash_attn_func(q, k, v, dropout_p=self.dropout, softmax_scale=None, causal=False, window_size=(-1, -1)) # out is same shape to q
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ return self.to_out(out.float())
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = Fp32LayerNorm(dim)
+ self.norm2 = Fp32LayerNorm(dim)
+ self.norm3 = Fp32LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention,
+ "softmax-flash": FlashAttention
+}
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class Fp32LayerNorm(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
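+# AdaNorm implements DiT-style adaptive LayerNorm: a SiLU + Linear head predicts a per-sample
+# shift and scale from the conditioning vector c, and modulate() applies them to the normalised
+# activations as x * (1 + scale) + shift.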
+class AdaNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(dim, 2 * dim, bias=True)
+ )
+ self.norm = Fp32LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, x, c): # x is fp32, c is fp16
+ shift, scale = self.adaLN_modulation(c.float()).chunk(2, dim=1) # bf16
+ x = modulate(self.norm(x), shift, scale) # fp32
+ return x
+
+
+class BasicTransformerBlockLRM(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, \
+ checkpoint=True):
+ super().__init__()
+
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ attn_mode = "softmax-flash" if FLASH_IS_AVAILABLE else attn_mode
+ assert attn_mode in ATTENTION_MODES
+ attn_cls = ATTENTION_MODES[attn_mode]
+
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
+ context_dim=context_dim) # cross-attn
+ self.attn2 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
+ context_dim=None) # self-attn
+
+ self.norm1 = Fp32LayerNorm(dim)
+ self.norm2 = Fp32LayerNorm(dim)
+ self.norm3 = Fp32LayerNorm(dim)
+
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None, cam_emb=None): # (torch.float32, torch.float32, torch.bfloat16)
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+
+ def _forward(self, x, context=None, cam_emb=None):
+
+ x = self.attn1(self.norm1(x), context=context) + x # cross-attn
+ x = self.attn2(self.norm2(x), context=None) + x # self-attn
+ x = self.ff(self.norm3(x)) + x
+
+ return x
+
+class ImgToTriplaneTransformer(nn.Module):
+    """
+    Stack of LRM-style transformer blocks: each block cross-attends its query tokens to the
+    image tokens passed in as `context`, then self-attends and applies a feed-forward layer;
+    a final fp32 LayerNorm is applied to the output.
+    """
+ def __init__(self, query_dim, n_heads, d_head, depth=1, dropout=0., context_dim=None, triplane_size=64):
+ super().__init__()
+
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlockLRM(query_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)])
+
+ self.norm = Fp32LayerNorm(query_dim, eps=1e-6)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ elif isinstance(module, nn.LayerNorm):
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ if module.weight is not None:
+ nn.init.constant_(module.weight, 1.0)
+ self.apply(_basic_init)
+
+ def forward(self, x, context=None, cam_emb=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = self.norm(x)
+ return x
+
+
+
+
+
diff --git a/svrm/ldm/modules/encoders/__init__.py b/svrm/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/svrm/ldm/modules/encoders/dinov2/__init__.py b/svrm/ldm/modules/encoders/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/svrm/ldm/modules/encoders/dinov2/hub/__init__.py b/svrm/ldm/modules/encoders/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/svrm/ldm/modules/encoders/dinov2/hub/backbones.py b/svrm/ldm/modules/encoders/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..51bb1ffe50525db3ed2ac78c3bb40776dd530645
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/hub/backbones.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/svrm/ldm/modules/encoders/dinov2/hub/utils.py b/svrm/ldm/modules/encoders/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7afea3273713518e891d1e6b8e86d58b4700fddc
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
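A quick check of `CenterPadding`, which pads the spatial dimensions up to the next multiple (e.g. the 14px patch size), splitting the padding evenly on both sides; a minimal sketch assuming the `svrm` package is importable:

```python
import torch
from svrm.ldm.modules.encoders.dinov2.hub.utils import CenterPadding

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 250, 250)
y = pad(x)
# 250 -> ceil(250 / 14) * 14 = 252, split as 1 px on each side
assert y.shape == (1, 3, 252, 252)
```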
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/__init__.py b/svrm/ldm/modules/encoders/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..70dff7a567de20c911559522e678dd6b605fdb9d
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlockMod
+from .attention import MemEffAttention
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/attention.py b/svrm/ldm/modules/encoders/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f62db59d697e151d4aaec272b897fd07c8a8ab
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/attention.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Attention)")
+ else:
+ warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
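A shape-level sanity check for the baseline `Attention` module (`MemEffAttention` only differs by dispatching to xFormers when it is available); a sketch assuming the package imports cleanly:

```python
import torch
from svrm.ldm.modules.encoders.dinov2.layers.attention import Attention

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
x = torch.randn(2, 257, 768)      # e.g. CLS + 16x16 patch tokens
y = attn(x)
assert y.shape == (2, 257, 768)   # self-attention preserves the (batch, tokens, dim) layout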
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/block.py b/svrm/ldm/modules/encoders/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0de346f02648eda86b6c9dc8a599e5da0f88636
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/block.py
@@ -0,0 +1,269 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import os
+import logging
+import warnings
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+from ....attention import AdaNorm
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Block)")
+ else:
+ warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (Block)")
+
+
+class BlockMod(nn.Module):
+ '''
+ Transformer block modified for camera conditioning: norm1/norm2 (AdaNorm by default) take a camera embedding, which forward() threads through both residual branches.
+ '''
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = AdaNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, cam_emb: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x, cam_emb)))
+
+ def ffn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x, cam_emb)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, cam_emb))
+ x = x + self.drop_path2(ffn_residual_func(x, cam_emb))
+ else:
+ x = x + attn_residual_func(x, cam_emb)
+ x = x + ffn_residual_func(x, cam_emb)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # drop_add_residual_stochastic_depth_list
+
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ # get_branges_scales
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ # add residuals
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlockMod(BlockMod):
+ def forward_nested(self, x_list: List[Tensor], cam_emb_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x, cam_emb))
+
+ x_list = drop_add_residual_stochastic_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x, cam_emb)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list, cam_emb_or_cam_emb_list):
+ if isinstance(x_or_x_list, Tensor) and isinstance(cam_emb_or_cam_emb_list, Tensor) :
+ return super().forward(x_or_x_list, cam_emb_or_cam_emb_list)
+ elif isinstance(x_or_x_list, list) and isinstance(cam_emb_or_cam_emb_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list, cam_emb_or_cam_emb_list)
+ else:
+ raise AssertionError
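`drop_add_residual_stochastic_depth` applies the residual branch to a random subset of the batch and rescales it, which is cheaper than per-sample masking at high drop rates. A minimal sketch with a stand-in residual function (purely illustrative; the real callers pass the attention/FFN branches), assuming the `svrm` package imports cleanly:

```python
import torch
from svrm.ldm.modules.encoders.dinov2.layers.block import drop_add_residual_stochastic_depth

x = torch.randn(8, 16, 32)  # (batch, tokens, dim)
out = drop_add_residual_stochastic_depth(
    x,
    residual_func=lambda t: 0.1 * t,  # hypothetical residual branch, for illustration only
    sample_drop_ratio=0.5,            # residual computed for ~half the batch, rescaled by 2x
)
assert out.shape == x.shape
```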
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/dino_head.py b/svrm/ldm/modules/encoders/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccca59999e1d686e1341281c61e8961f1b0e6545
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
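`DINOHead` is the self-distillation projection head carried over from the upstream DINOv2 code (L2-normalized bottleneck followed by a weight-normalized prototype layer); a quick shape check, assuming the package imports cleanly:

```python
import torch
from svrm.ldm.modules.encoders.dinov2.layers.dino_head import DINOHead

head = DINOHead(in_dim=768, out_dim=65536)
x = torch.randn(4, 768)
logits = head(x)
assert logits.shape == (4, 65536)
```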
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/drop_path.py b/svrm/ldm/modules/encoders/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bb1487b0eed4cb14dc0d5d1ee57a2acc78de34a
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
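`DropPath` zeroes whole samples' residual branches during training and rescales the survivors by 1/keep_prob, and is the identity in eval mode; a small sketch assuming the package imports cleanly:

```python
import torch
from svrm.ldm.modules.encoders.dinov2.layers.drop_path import DropPath

dp = DropPath(drop_prob=0.2)
x = torch.ones(8, 4, 16)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time

dp.train()
y = dp(x)  # per sample: either all zeros, or all values scaled to 1 / 0.8
assert ((y == 0) | torch.isclose(y, torch.full_like(y, 1 / 0.8))).all()
```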
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py b/svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..5468ee2dce0a9446c028791de5cff1ff068a4fe5
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
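`LayerScale` is a learnable per-channel gain initialized to a small value so each residual branch starts near zero; for example:

```python
import torch
from svrm.ldm.modules.encoders.dinov2.layers.layer_scale import LayerScale

ls = LayerScale(dim=768, init_values=1e-5)
x = torch.randn(2, 257, 768)
assert torch.allclose(ls(x), x * 1e-5)  # starts as a tiny per-channel scale
```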
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/mlp.py b/svrm/ldm/modules/encoders/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..0965768a9aef04ac6b81322f4dd60cf035159e91
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/patch_embed.py b/svrm/ldm/modules/encoders/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c3aaf46c523ab1ae27430419187bbad11e302ab
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
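`PatchEmbed` turns an image into a token sequence with a strided convolution; with the 14px patches used by DINOv2, a 224px input yields a 16x16 grid. A sketch assuming the package imports cleanly:

```python
import torch
from svrm.ldm.modules.encoders.dinov2.layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=768)
x = torch.randn(1, 3, 224, 224)
tokens = pe(x)
assert tokens.shape == (1, (224 // 14) ** 2, 768)  # 16 * 16 = 256 patch tokens
```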
diff --git a/svrm/ldm/modules/encoders/dinov2/layers/swiglu_ffn.py b/svrm/ldm/modules/encoders/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3765d5def655f0a23f3803f4c7f79c33d3ecfd55
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
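`SwiGLUFFN` packs both input projections into `w12`, gates with SiLU, and projects back with `w3`; `SwiGLUFFNFused` additionally shrinks the requested hidden width by 2/3 (rounded up to a multiple of 8) so the parameter count stays comparable to a plain MLP. A sketch of the plain fallback path (no xFormers assumed):

```python
import torch
from svrm.ldm.modules.encoders.dinov2.layers.swiglu_ffn import SwiGLUFFN

ffn = SwiGLUFFN(in_features=768, hidden_features=2048)
x = torch.randn(2, 257, 768)
assert ffn(x).shape == (2, 257, 768)  # SiLU(x1) * x2 gating, then the w3 projection
```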
diff --git a/svrm/ldm/modules/encoders/dinov2/models/__init__.py b/svrm/ldm/modules/encoders/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c40a2a10ce2b1e39f666328132c2a80111072d
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
diff --git a/svrm/ldm/modules/encoders/dinov2/models/vision_transformer.py b/svrm/ldm/modules/encoders/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..666987e13e5c93ecc62e1d3e3b97bca0987d5edd
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2/models/vision_transformer.py
@@ -0,0 +1,415 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlockMod as BlockMod
+from ....attention import AdaNorm
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=BlockMod,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ pos_emb_dim=768,
+ cam_cond_dim=20
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+
+ norm_layer = AdaNorm
+ self.cam_embed = nn.Sequential(
+ nn.Linear(cam_cond_dim, pos_emb_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(pos_emb_dim, pos_emb_dim, bias=True))
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features_list_with_camera(self, x_list, cam_cond_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ cam_emb = [self.cam_embed(cam_cond) for cam_cond in cam_cond_list]
+ for blk in self.blocks:
+ x = blk(x, cam_emb)
+
+ all_x = x
+ all_cam_emb = cam_emb
+ output = []
+ for x, cam_emb, masks in zip(all_x, all_cam_emb, masks_list):
+ x_norm = self.norm(x, cam_emb)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def forward_features_with_camera(self, x, cam_cond, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list_with_camera(x, cam_cond, masks)
+ cam_emb = self.cam_embed(cam_cond)
+ x = self.prepare_tokens_with_masks(x, masks)
+ for blk in self.blocks:
+ x = blk(x, cam_emb)
+ x_norm = self.norm(x, cam_emb)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_inter_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_inter_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+
+ ret = self.forward_features_with_camera(*args, **kwargs)
+
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, AdaNorm):
+ nn.init.constant_(module.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(module.adaLN_modulation[-1].bias, 0)
+ elif isinstance(module, nn.LayerNorm):
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ if module.weight is not None:
+ nn.init.constant_(module.weight, 1.0)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(BlockMod, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(BlockMod, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
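Token bookkeeping for the `forward_features*` outputs: the sequence is [CLS], then the register tokens (if any), then the patch tokens, and the returned dict slices `x_norm` accordingly. Illustrative arithmetic only, no model instantiation:

```python
# For a 448x448 input with patch_size=14 and num_register_tokens=4:
num_patches = (448 // 14) ** 2      # 32 * 32 = 1024 patch tokens
seq_len = 1 + 4 + num_patches       # CLS + registers + patches = 1029
# forward_features* then returns:
#   x_norm[:, 0]    -> "x_norm_clstoken"
#   x_norm[:, 1:5]  -> "x_norm_regtokens"
#   x_norm[:, 5:]   -> "x_norm_patchtokens"
print(seq_len)
```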
diff --git a/svrm/ldm/modules/encoders/dinov2_mod.py b/svrm/ldm/modules/encoders/dinov2_mod.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c4bb6c9e09734b18d621b9955d46ca7cd6275f
--- /dev/null
+++ b/svrm/ldm/modules/encoders/dinov2_mod.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .dinov2.hub.backbones import dinov2_vitb14
+
+class FrozenDinoV2ImageEmbedder(nn.Module):
+ """
+ Uses the dinov2 image encoder with camera modulation.
+ Not actually frozen; set cond_stage_trainable=False in the config if you want the encoder frozen.
+ """
+ def __init__(
+ self,
+ version='dinov2_vitb14',
+ ckpt_path=None,
+ lrm_mode='plain_lrm',
+ ):
+ super().__init__()
+ self.lrm_mode = lrm_mode
+ assert version in ['dinov2_vitb14', 'dinov2_vits14', 'dinov2_vitl14', 'dinov2_vitg14']
+
+
+ self.model = dinov2_vitb14(pretrained=False)
+
+ if ckpt_path is not None:
+ self.load_pretrained(ckpt_path)
+ else:
+ print('No pretrained checkpoint provided for the dinov2 encoder ...')
+
+
+ def load_pretrained(self, ckpt_path):
+ print('Loading dinov2 encoder ...')
+ orig_state_dict = torch.load(ckpt_path, map_location='cpu')
+ try:
+ ret = self.model.load_state_dict(orig_state_dict, strict=False)
+ print(ret)
+ print('Successfully loaded orig state dict')
+ except Exception:
+ new_state_dict = OrderedDict()
+ for k, v in orig_state_dict['state_dict'].items():
+ if 'img_encoder' in k:
+ new_state_dict[k.replace('img_encoder.model.', '')] = v
+ ret = self.model.load_state_dict(new_state_dict, strict=False)
+ print(ret)
+ print('Successfully loaded new state dict')
+
+
+ def forward(self, x, *args, **kwargs):
+ ret = self.model.forward_features_with_camera(x, *args, **kwargs)
+ output = torch.cat([ret['x_norm_clstoken'].unsqueeze(1), ret['x_norm_patchtokens']], dim=1)
+ return output
diff --git a/svrm/ldm/modules/rendering_neus/__init__.py b/svrm/ldm/modules/rendering_neus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a7f09c03caa76149581d5b15b541aecb3892d6a
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/svrm/ldm/modules/rendering_neus/mesh.py b/svrm/ldm/modules/rendering_neus/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e224b84d91cf5a999f83f24cd94873151494094
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/mesh.py
@@ -0,0 +1,311 @@
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ...utils.typing import *
+
+def dot(x, y):
+ return torch.sum(x * y, -1, keepdim=True)
+
+class Mesh:
+ def __init__(
+ self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], v_rgb: Optional[Float[Tensor, "Nv 3"]] = None, **kwargs
+ ) -> None:
+ self.v_pos: Float[Tensor, "Nv 3"] = v_pos
+ self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
+ self.v_rgb: Optional[Float[Tensor, "Nv 3"]] = v_rgb
+
+ self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
+ self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
+ self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
+ self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None
+ # self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
+ self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
+ self.extras: Dict[str, Any] = {}
+ for k, v in kwargs.items():
+ self.add_extra(k, v)
+
+ def add_extra(self, k, v) -> None:
+ self.extras[k] = v
+
+ def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh:
+ if self.requires_grad:
+ print("Mesh is differentiable, not removing outliers")
+ return self
+
+ # use trimesh to first split the mesh into connected components
+ # then remove the components with less than n_face_threshold faces
+ import trimesh
+
+ # construct a trimesh object
+ mesh = trimesh.Trimesh(
+ vertices=self.v_pos.detach().cpu().numpy(),
+ faces=self.t_pos_idx.detach().cpu().numpy(),
+ )
+
+ # split the mesh into connected components
+ components = mesh.split(only_watertight=False)
+ # log the number of faces in each component
+ print(
+ "Mesh has {} components, with faces: {}".format(
+ len(components), [c.faces.shape[0] for c in components]
+ )
+ )
+
+ n_faces_threshold: int
+ if isinstance(outlier_n_faces_threshold, float):
+ # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
+ n_faces_threshold = int(
+ max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
+ )
+ else:
+ # set the threshold directly to outlier_n_faces_threshold
+ n_faces_threshold = outlier_n_faces_threshold
+
+ # log the threshold
+ print(
+ "Removing components with less than {} faces".format(n_faces_threshold)
+ )
+
+ # remove the components with less than n_face_threshold faces
+ components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
+
+ # log the number of faces in each component after removing outliers
+ print(
+ "Mesh has {} components after removing outliers, with faces: {}".format(
+ len(components), [c.faces.shape[0] for c in components]
+ )
+ )
+ # merge the components
+ mesh = trimesh.util.concatenate(components)
+
+ # convert back to our mesh format
+ v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
+ t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
+
+ clean_mesh = Mesh(v_pos, t_pos_idx)
+ # keep the extras unchanged
+
+ if len(self.extras) > 0:
+ clean_mesh.extras = self.extras
+ print(
+ f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
+ )
+ return clean_mesh
+
+ @property
+ def requires_grad(self):
+ return self.v_pos.requires_grad
+
+ @property
+ def v_nrm(self):
+ if self._v_nrm is None:
+ self._v_nrm = self._compute_vertex_normal()
+ return self._v_nrm
+
+ @property
+ def v_tng(self):
+ if self._v_tng is None:
+ self._v_tng = self._compute_vertex_tangent()
+ return self._v_tng
+
+ @property
+ def v_tex(self):
+ if self._v_tex is None:
+ self._v_tex, self._t_tex_idx = self._unwrap_uv()
+ return self._v_tex
+
+ @property
+ def t_tex_idx(self):
+ if self._t_tex_idx is None:
+ self._v_tex, self._t_tex_idx = self._unwrap_uv()
+ return self._t_tex_idx
+
+ # @property
+ # def v_rgb(self):
+ # return self._v_rgb
+
+ @property
+ def edges(self):
+ if self._edges is None:
+ self._edges = self._compute_edges()
+ return self._edges
+
+ def _compute_vertex_normal(self):
+ i0 = self.t_pos_idx[:, 0]
+ i1 = self.t_pos_idx[:, 1]
+ i2 = self.t_pos_idx[:, 2]
+
+ v0 = self.v_pos[i0, :]
+ v1 = self.v_pos[i1, :]
+ v2 = self.v_pos[i2, :]
+
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+
+ # Splat face normals to vertices
+ v_nrm = torch.zeros_like(self.v_pos)
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
+
+ # Normalize, replace zero (degenerated) normals with some default value
+ v_nrm = torch.where(
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
+ )
+ v_nrm = F.normalize(v_nrm, dim=1)
+
+ if torch.is_anomaly_enabled():
+ assert torch.all(torch.isfinite(v_nrm))
+
+ return v_nrm
+
+ def _compute_vertex_tangent(self):
+ vn_idx = [None] * 3
+ pos = [None] * 3
+ tex = [None] * 3
+ for i in range(0, 3):
+ pos[i] = self.v_pos[self.t_pos_idx[:, i]]
+ tex[i] = self.v_tex[self.t_tex_idx[:, i]]
+ # t_nrm_idx is always the same as t_pos_idx
+ vn_idx[i] = self.t_pos_idx[:, i]
+
+ tangents = torch.zeros_like(self.v_nrm)
+ tansum = torch.zeros_like(self.v_nrm)
+
+ # Compute tangent space for each triangle
+ uve1 = tex[1] - tex[0]
+ uve2 = tex[2] - tex[0]
+ pe1 = pos[1] - pos[0]
+ pe2 = pos[2] - pos[0]
+
+ nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
+ denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]
+
+ # Avoid division by zero for degenerated texture coordinates
+ tang = nom / torch.where(
+ denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
+ )
+
+ # Update all 3 vertices
+ for i in range(0, 3):
+ idx = vn_idx[i][:, None].repeat(1, 3)
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
+ tansum.scatter_add_(
+ 0, idx, torch.ones_like(tang)
+ ) # tansum[n_i] = tansum[n_i] + 1
+ tangents = tangents / tansum
+
+ # Normalize and make sure tangent is perpendicular to normal
+ tangents = F.normalize(tangents, dim=1)
+ tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
+
+ if torch.is_anomaly_enabled():
+ assert torch.all(torch.isfinite(tangents))
+
+ return tangents
+
+ def _unwrap_uv(
+ self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
+ ):
+ print("Using xatlas to perform UV unwrapping, may take a while ...")
+
+ import xatlas
+
+ atlas = xatlas.Atlas()
+ atlas.add_mesh(
+ self.v_pos.detach().cpu().numpy(),
+ self.t_pos_idx.cpu().numpy(),
+ )
+ co = xatlas.ChartOptions()
+ po = xatlas.PackOptions()
+ for k, v in xatlas_chart_options.items():
+ setattr(co, k, v)
+ for k, v in xatlas_pack_options.items():
+ setattr(po, k, v)
+ atlas.generate(co, po)
+ vmapping, indices, uvs = atlas.get_mesh(0)
+ vmapping = (
+ torch.from_numpy(
+ vmapping.astype(np.uint64, casting="same_kind").view(np.int64)
+ )
+ .to(self.v_pos.device)
+ .long()
+ )
+ uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
+ indices = (
+ torch.from_numpy(
+ indices.astype(np.uint64, casting="same_kind").view(np.int64)
+ )
+ .to(self.v_pos.device)
+ .long()
+ )
+ return uvs, indices
+
+ def unwrap_uv(
+ self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
+ ):
+ self._v_tex, self._t_tex_idx = self._unwrap_uv(
+ xatlas_chart_options, xatlas_pack_options
+ )
+
+ def set_vertex_color(self, v_rgb):
+ assert v_rgb.shape[0] == self.v_pos.shape[0]
+ self.v_rgb = v_rgb
+
+ def _compute_edges(self):
+ # Compute edges
+ edges = torch.cat(
+ [
+ self.t_pos_idx[:, [0, 1]],
+ self.t_pos_idx[:, [1, 2]],
+ self.t_pos_idx[:, [2, 0]],
+ ],
+ dim=0,
+ )
+ edges = edges.sort()[0]
+ edges = torch.unique(edges, dim=0)
+ return edges
+
+ def normal_consistency(self) -> Float[Tensor, ""]:
+ edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
+ nc = (
+ 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
+ ).mean()
+ return nc
+
+ def _laplacian_uniform(self):
+ # from stable-dreamfusion
+ # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
+ verts, faces = self.v_pos, self.t_pos_idx
+
+ V = verts.shape[0]
+ F = faces.shape[0]
+
+ # Neighbor indices
+ ii = faces[:, [1, 2, 0]].flatten()
+ jj = faces[:, [2, 0, 1]].flatten()
+ adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
+ dim=1
+ )
+ adj_values = torch.ones(adj.shape[1]).to(verts)
+
+ # Diagonal indices
+ diag_idx = adj[0]
+
+ # Build the sparse matrix
+ idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
+ values = torch.cat((-adj_values, adj_values))
+
+ # The coalesce operation sums the duplicate indices, resulting in the
+ # correct diagonal
+ return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
+
+ def laplacian(self) -> Float[Tensor, ""]:
+ with torch.no_grad():
+ L = self._laplacian_uniform()
+ loss = L.mm(self.v_pos)
+ loss = loss.norm(dim=1)
+ loss = loss.mean()
+ return loss
\ No newline at end of file
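The derived quantities on `Mesh` (normals, edges, normal consistency, Laplacian) are computed lazily from `v_pos`/`t_pos_idx`; a minimal sketch with a single triangle, assuming the repo's `svrm` package and its `utils.typing` helpers are importable:

```python
import torch
from svrm.ldm.modules.rendering_neus.mesh import Mesh

v_pos = torch.tensor([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]])
t_pos_idx = torch.tensor([[0, 1, 2]])
mesh = Mesh(v_pos, t_pos_idx, v_rgb=None)

print(mesh.v_nrm)                 # all three vertex normals point along +z for this triangle
print(mesh.edges)                 # the three undirected edges (0,1), (1,2), (0,2)
print(mesh.normal_consistency())  # 0 for a single flat triangle
```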
diff --git a/svrm/ldm/modules/rendering_neus/rasterize.py b/svrm/ldm/modules/rendering_neus/rasterize.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e270fbf75be18e84540e31bcf566a83efb4827a
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/rasterize.py
@@ -0,0 +1,78 @@
+import nvdiffrast.torch as dr
+import torch
+
+from ...utils.typing import *
+
+
+class NVDiffRasterizerContext:
+ def __init__(self, context_type: str, device: torch.device) -> None:
+ self.device = device
+ self.ctx = self.initialize_context(context_type, device)
+
+ def initialize_context(
+ self, context_type: str, device: torch.device
+ ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
+ if context_type == "gl":
+ return dr.RasterizeGLContext(device=device)
+ elif context_type == "cuda":
+ return dr.RasterizeCudaContext(device=device)
+ else:
+ raise ValueError(f"Unknown rasterizer context type: {context_type}")
+
+ def vertex_transform(
+ self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
+ ) -> Float[Tensor, "B Nv 4"]:
+ verts_homo = torch.cat(
+ [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1
+ )
+ return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
+
+ def rasterize(
+ self,
+ pos: Float[Tensor, "B Nv 4"],
+ tri: Integer[Tensor, "Nf 3"],
+ resolution: Union[int, Tuple[int, int]],
+ ):
+ # rasterize in instance mode (single topology)
+ return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)
+
+ def rasterize_one(
+ self,
+ pos: Float[Tensor, "Nv 4"],
+ tri: Integer[Tensor, "Nf 3"],
+ resolution: Union[int, Tuple[int, int]],
+ ):
+ # rasterize one single mesh under a single viewpoint
+ rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
+ return rast[0], rast_db[0]
+
+ def antialias(
+ self,
+ color: Float[Tensor, "B H W C"],
+ rast: Float[Tensor, "B H W 4"],
+ pos: Float[Tensor, "B Nv 4"],
+ tri: Integer[Tensor, "Nf 3"],
+ ) -> Float[Tensor, "B H W C"]:
+ return dr.antialias(color.float(), rast, pos.float(), tri.int())
+
+ def interpolate(
+ self,
+ attr: Float[Tensor, "B Nv C"],
+ rast: Float[Tensor, "B H W 4"],
+ tri: Integer[Tensor, "Nf 3"],
+ rast_db=None,
+ diff_attrs=None,
+ ) -> Float[Tensor, "B H W C"]:
+ return dr.interpolate(
+ attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
+ )
+
+ def interpolate_one(
+ self,
+ attr: Float[Tensor, "Nv C"],
+ rast: Float[Tensor, "B H W 4"],
+ tri: Integer[Tensor, "Nf 3"],
+ rast_db=None,
+ diff_attrs=None,
+ ) -> Float[Tensor, "B H W C"]:
+ return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
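Typical use of `NVDiffRasterizerContext`: transform vertices to clip space, rasterize, then read the coverage channel as a mask. A usage sketch only; it requires nvdiffrast and a CUDA device, and the identity MVP matrix here is purely illustrative:

```python
import torch
from svrm.ldm.modules.rendering_neus.rasterize import NVDiffRasterizerContext

device = torch.device("cuda")
ctx = NVDiffRasterizerContext("cuda", device)

verts = torch.tensor([[-0.5, -0.5, 0.0],
                      [ 0.5, -0.5, 0.0],
                      [ 0.0,  0.5, 0.0]], device=device)
faces = torch.tensor([[0, 1, 2]], device=device)
mvp = torch.eye(4, device=device)[None]       # (1, 4, 4) identity camera, illustrative only
pos_clip = ctx.vertex_transform(verts, mvp)   # (1, 3, 4) homogeneous clip-space positions
rast, _ = ctx.rasterize(pos_clip, faces, (256, 256))
mask = (rast[..., 3] > 0).float()             # 1 where the triangle covers a pixel
```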
diff --git a/svrm/ldm/modules/rendering_neus/synthesizer.py b/svrm/ldm/modules/rendering_neus/synthesizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7acb391a27dead7490b87f742abc05fd1087d218
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/synthesizer.py
@@ -0,0 +1,277 @@
+# ORIGINAL LICENSE
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# Modified by Zexin He
+# The modifications are subject to the same license as the original.
+
+
+import itertools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils.renderer import ImportanceRenderer, sample_from_planes
+from .utils.ray_sampler import RaySampler
+from ...utils.ops import get_rank
+
+
+class OSGDecoder(nn.Module):
+ """
+ Triplane decoder that gives RGB and sigma values from sampled features.
+ Using ReLU here instead of Softplus in the original implementation.
+
+ Reference:
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
+ """
+ def __init__(self, n_features: int,
+ hidden_dim: int = 64,
+ num_layers: int = 2,
+ activation: nn.Module = nn.ReLU,
+ sdf_bias='sphere',
+ sdf_bias_params=0.5,
+ output_normal=True,
+ normal_type='finite_difference'):
+ super().__init__()
+ self.sdf_bias = sdf_bias
+ self.sdf_bias_params = sdf_bias_params
+ self.output_normal = output_normal
+ self.normal_type = normal_type
+ self.net = nn.Sequential(
+ nn.Linear(3 * n_features, hidden_dim),
+ activation(),
+ *itertools.chain(*[[
+ nn.Linear(hidden_dim, hidden_dim),
+ activation(),
+ ] for _ in range(num_layers - 2)]),
+ nn.Linear(hidden_dim, 1 + 3),
+ )
+ # init all bias to zero
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.zeros_(m.bias)
+
+ def forward(self, ray_directions, sample_coordinates, plane_axes, planes, options):
+ # Aggregate features by mean
+ # sampled_features = sampled_features.mean(1)
+ # Aggregate features by concatenation
+ # torch.set_grad_enabled(True)
+ # sample_coordinates.requires_grad_(True)
+
+ sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
+
+
+ _N, n_planes, _M, _C = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+ x = sampled_features
+
+ N, M, C = x.shape
+ # x = x.contiguous().view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+
+ sdf = x[..., 0:1]
+ # import ipdb; ipdb.set_trace()
+ # print(f'sample_coordinates shape: {sample_coordinates.shape}')
+ # sdf = self.get_shifted_sdf(sample_coordinates, sdf)
+
+ # calculate normal
+ eps = 0.01
+ offsets = torch.as_tensor(
+ [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
+ ).to(sample_coordinates)
+ points_offset = (
+ sample_coordinates[..., None, :] + offsets # Float[Tensor, "... 3 3"]
+ ).clamp(options['sampler_bbox_min'], options['sampler_bbox_max'])
+
+ sdf_offset_list = [self.forward_sdf(
+ plane_axes,
+ planes,
+ points_offset[:,:,i,:],
+ options
+ ).unsqueeze(-2) for i in range(points_offset.shape[-2])] # Float[Tensor, "... 3 1"]
+ # import ipdb; ipdb.set_trace()
+
+ sdf_offset = torch.cat(sdf_offset_list, -2)
+ sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps
+
+ normal = F.normalize(sdf_grad, dim=-1).to(sdf.dtype)
+ return {'rgb': rgb, 'sdf': sdf, 'normal': normal, 'sdf_grad': sdf_grad}
+
+ def forward_sdf(self, plane_axes, planes, points_offset, options):
+
+ sampled_features = sample_from_planes(plane_axes, planes, points_offset, padding_mode='zeros', box_warp=options['box_warp'])
+ _N, n_planes, _M, _C = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+ x = sampled_features
+
+ N, M, C = x.shape
+ # x = x.contiguous().view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ sdf = x[..., 0:1]
+ # sdf = self.get_shifted_sdf(points_offset, sdf)
+ return sdf
+
+ def get_shifted_sdf(
+ self, points, sdf
+ ):
+ if self.sdf_bias == "sphere":
+ assert isinstance(self.sdf_bias_params, float)
+ radius = self.sdf_bias_params
+ sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
+ else:
+ raise ValueError(f"Unknown sdf bias {self.sdf_bias}")
+ return sdf + sdf_bias.to(sdf.dtype)
+
+
+class TriplaneSynthesizer(nn.Module):
+ """
+ Synthesizer that renders a triplane volume with planes and a camera.
+
+ Reference:
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
+ """
+
+ DEFAULT_RENDERING_KWARGS = {
+ 'ray_start': 'auto',
+ 'ray_end': 'auto',
+ 'box_warp': 1.2,
+ # 'box_warp': 1.,
+ 'white_back': True,
+ 'disparity_space_sampling': False,
+ 'clamp_mode': 'softplus',
+ # 'sampler_bbox_min': -1,
+ # 'sampler_bbox_max': 1.,
+ 'sampler_bbox_min': -0.6,
+ 'sampler_bbox_max': 0.6,
+ }
+ print('DEFAULT_RENDERING_KWARGS')
+ print(DEFAULT_RENDERING_KWARGS)
+
+
+ def __init__(self, triplane_dim: int, samples_per_ray: int, osg_decoder='default'):
+ super().__init__()
+
+ # attributes
+ self.triplane_dim = triplane_dim
+ self.rendering_kwargs = {
+ **self.DEFAULT_RENDERING_KWARGS,
+ 'depth_resolution': samples_per_ray,
+ 'depth_resolution_importance': 0
+ # 'depth_resolution': samples_per_ray // 2,
+ # 'depth_resolution_importance': samples_per_ray // 2,
+ }
+
+ # renderings
+ self.renderer = ImportanceRenderer()
+ self.ray_sampler = RaySampler()
+ # modules
+ if osg_decoder == 'default':
+ self.decoder = OSGDecoder(n_features=triplane_dim)
+ else:
+ raise NotImplementedError
+
+ def forward(self, planes, ray_origins, ray_directions, render_size, bgcolor=None):
+ # planes: (N, 3, D', H', W')
+ # render_size: int
+ assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
+
+
+ # Perform volume rendering
+ rgb_samples, depth_samples, weights_samples, sdf_grad, normal_samples = self.renderer(
+ planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, bgcolor
+ )
+ N = planes.shape[0]
+
+ # zhaohx : add for normals
+ normal_samples = F.normalize(normal_samples, dim=-1)
+ normal_samples = (normal_samples + 1.0) / 2.0 # for visualization
+ normal_samples = torch.lerp(torch.zeros_like(normal_samples), normal_samples, weights_samples)
+
+ # Reshape into 'raw' neural-rendered image
+ Himg = Wimg = render_size
+ rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, rgb_samples.shape[-1], Himg, Wimg).contiguous()
+ depth_images = depth_samples.permute(0, 2, 1).reshape(N, 1, Himg, Wimg)
+ weight_images = weights_samples.permute(0, 2, 1).reshape(N, 1, Himg, Wimg)
+
+ # zhaohx : add for normals
+ normal_images = normal_samples.permute(0, 2, 1).reshape(N, normal_samples.shape[-1], Himg, Wimg).contiguous()
+
+ # return {
+ # 'images_rgb': rgb_images,
+ # 'images_depth': depth_images,
+ # 'images_weight': weight_images,
+ # }
+
+ return {
+ 'comp_rgb': rgb_images,
+ 'comp_depth': depth_images,
+ 'opacity': weight_images,
+ 'sdf_grad': sdf_grad,
+ 'comp_normal': normal_images
+ }
+ # to output normals, add them to this return dict
+
+ def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
+ # planes: (N, 3, D', H', W')
+ # grid_size: int
+ # aabb: (N, 2, 3)
+ if aabb is None:
+ aabb = torch.tensor([
+ [self.rendering_kwargs['sampler_bbox_min']] * 3,
+ [self.rendering_kwargs['sampler_bbox_max']] * 3,
+ ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
+ assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
+ N = planes.shape[0]
+
+ # create grid points for triplane query
+ grid_points = []
+ for i in range(N):
+ grid_points.append(torch.stack(torch.meshgrid(
+ torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
+ torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
+ torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
+ indexing='ij',
+ ), dim=-1).reshape(-1, 3))
+ cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
+
+ features = self.forward_points(planes, cube_grid)
+
+ # reshape into grid
+ features = {
+ k: v.reshape(N, grid_size, grid_size, grid_size, -1)
+ for k, v in features.items()
+ }
+ return features
+
+ def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
+ # planes: (N, 3, D', H', W')
+ # points: (N, P, 3)
+ N, P = points.shape[:2]
+
+ # query triplane in chunks
+ outs = []
+ for i in range(0, points.shape[1], chunk_size):
+ chunk_points = points[:, i:i+chunk_size]
+
+ # query triplane
+ # chunk_out = self.renderer.run_model_activated(
+ chunk_out = self.renderer.run_model(
+ planes=planes,
+ decoder=self.decoder,
+ sample_coordinates=chunk_points,
+ sample_directions=torch.zeros_like(chunk_points),
+ options=self.rendering_kwargs,
+ )
+ outs.append(chunk_out)
+
+ # concatenate the outputs
+ point_features = {
+ k: torch.cat([out[k] for out in outs], dim=1)
+ for k in outs[0].keys()
+ }
+ return point_features
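`OSGDecoder.forward` above recovers normals by finite differences: the SDF is re-queried at three axis-aligned offsets of size `eps` and the gradient is approximated as (sdf(x + eps*e_i) - sdf(x)) / eps, then normalized. A self-contained sketch of the same scheme on an analytic SDF (a unit sphere, purely illustrative and not part of this diff):

    import torch
    import torch.nn.functional as F

    def sphere_sdf(p: torch.Tensor) -> torch.Tensor:
        # signed distance to a unit sphere centred at the origin
        return p.norm(dim=-1, keepdim=True) - 1.0

    def finite_difference_normal(p: torch.Tensor, eps: float = 0.01) -> torch.Tensor:
        offsets = torch.eye(3, dtype=p.dtype, device=p.device) * eps     # (3, 3)
        sdf = sphere_sdf(p)                                              # (..., 1)
        sdf_offset = sphere_sdf(p[..., None, :] + offsets)               # (..., 3, 1)
        grad = (sdf_offset[..., 0] - sdf) / eps                          # forward difference
        return F.normalize(grad, dim=-1)

    pts = torch.randn(4, 3)
    print(finite_difference_normal(pts))   # approximately pts / |pts|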
diff --git a/svrm/ldm/modules/rendering_neus/third_party/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb64cc793df9f8474d19b5b29acdd82ed7870da
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/svrm/ldm/modules/rendering_neus/third_party/custom_ops.py b/svrm/ldm/modules/rendering_neus/third_party/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a70b33337358d67f34332ba1456f3efc1f08d68
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/custom_ops.py
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+ name = torch.cuda.get_device_name().lower()
+ out = []
+ for c in name:
+ if re.match('[a-z0-9_-]+', c):
+ out.append(c)
+ else:
+ out.append('-')
+ return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+ if headers is None:
+ headers = []
+ if source_dir is not None:
+ sources = [os.path.join(source_dir, fname) for fname in sources]
+ headers = [os.path.join(source_dir, fname) for fname in headers]
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+ verbose_build = (verbosity == 'full')
+
+ # Compile and load.
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+ # break the build or unnecessarily restrict what's available to nvcc.
+ # Unset it to let nvcc decide based on what's available on the
+ # machine.
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ #
+ # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
+ # around the *.cu dependency bug in ninja config.
+ #
+ all_source_files = sorted(sources + headers)
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+ # Compute combined hash digest for all source files.
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+
+ # Select cached build directory name.
+ source_digest = hash_md5.hexdigest()
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+ if not os.path.isdir(cached_build_dir):
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+ os.makedirs(tmpdir)
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
+ try:
+ os.replace(tmpdir, cached_build_dir) # atomic
+ except OSError:
+ # source directory already exists, delete tmpdir and its contents.
+ shutil.rmtree(tmpdir)
+ if not os.path.isdir(cached_build_dir): raise
+
+ # Compile.
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+
+ # Load.
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache dict.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
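`get_plugin` keys its cached build directory on a combined md5 digest of all input source files, so unchanged sources always resolve to the same JIT build and rebuilds stay incremental. A short sketch of that digest step in isolation (illustrative only; `paths` is a hypothetical list of source file paths):

    import hashlib

    def source_digest(paths):
        # combined md5 over the raw bytes of every input file, mirroring the
        # cached-build-directory selection in get_plugin() above
        h = hashlib.md5()
        for p in sorted(paths):
            with open(p, "rb") as f:
                h.update(f.read())
        return h.hexdigest()

    # e.g. cached_build_dir = f"{build_top_dir}/{source_digest(sources)}-{gpu_name}"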
diff --git a/svrm/ldm/modules/rendering_neus/third_party/dnnlib/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e4540b93aba3437b83b517eeecdf66c133fbdfd
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/svrm/ldm/modules/rendering_neus/third_party/dnnlib/util.py b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a26590e44be0e9a5011699a88910c0df12ad68d3
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/dnnlib/util.py
@@ -0,0 +1,493 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
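The two exports used elsewhere in this diff are `EasyDict` (attribute-style dict access) and the cache helpers backing `open_url`. A brief usage sketch; the top-level import path is inferred from this diff's file layout and the URL is a placeholder:

    from svrm.ldm.modules.rendering_neus.third_party import dnnlib

    cfg = dnnlib.EasyDict(lr=1e-4, batch=8)
    cfg.epochs = 10                          # attribute access writes into the dict
    print(cfg["lr"], cfg.epochs)

    # downloads are cached under ~/.cache/dnnlib/downloads by default
    with dnnlib.util.open_url("https://example.com/weights.pkl", cache=True) as f:
        data = f.read()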
diff --git a/svrm/ldm/modules/rendering_neus/third_party/misc.py b/svrm/ldm/modules/rendering_neus/third_party/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..09428a800825bee0e3828460ccd5a0c9e1e648aa
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/misc.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+from . import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to temporarily suppress known warnings in torch.jit.trace().
+# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+@contextlib.contextmanager
+def suppress_tracer_warnings():
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+ warnings.filters.insert(0, flt)
+ yield
+ warnings.filters.remove(flt)
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = dict(named_params_and_buffers(src_module))
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+#----------------------------------------------------------------------------
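Two frequently used helpers here are `assert_shape` (shape checks that also work under `torch.jit.trace`) and `constant` (cached constant tensors). A minimal sketch, with the import path inferred from this diff's file layout:

    import torch
    from svrm.ldm.modules.rendering_neus.third_party import misc

    x = torch.randn(4, 3, 256, 256)
    misc.assert_shape(x, [None, 3, 256, 256])     # None lets the batch size vary

    # constant() caches the tensor, avoiding repeated CPU->GPU copies of the same value
    ones = misc.constant([1.0, 1.0], device=x.device)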
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb64cc793df9f8474d19b5b29acdd82ed7870da
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cpp b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f507b2807e4e2e8cd6e79de122767a88ee1b052a
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cpp
@@ -0,0 +1,103 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel<scalar_t>(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
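The extension compiled from this file exposes a single `bias_act` entry point through pybind11; empty tensors stand in for the unused reference and gradient inputs, and the trailing arguments follow the C++ signature above. A hedged sketch of calling it from Python, assuming `_plugin` is the module returned by `custom_ops.get_plugin(...)` as done in bias_act.py later in this diff:

    import torch

    # argument order follows the C++ signature:
    # (x, b, xref, yref, dy, grad, dim, act, alpha, gain, clamp)
    null = torch.empty([0], device="cuda")
    x = torch.randn(8, 64, device="cuda")
    b = torch.randn(64, device="cuda")
    # act=3 selects lrelu (see the kernel table in bias_act.cu); clamp=-1 disables clamping
    y = _plugin.bias_act(x, b, null, null, null, 0, 1, 3, 0.2, 1.0, -1.0)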
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cu b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a0118b34928895bc3c4f68d96c5bb2750b42aa72
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cu
@@ -0,0 +1,177 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <c10/util/Half.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double> { typedef double scalar_t; };
+template <> struct InternalType<float> { typedef float scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
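For the forward pass (grad == 0) the kernel above computes y = clamp(act(x + b) * gain), with the activation selected by the integer act index. A plain PyTorch equivalent for the lrelu case, matching what the reference path in bias_act.py computes (illustrative sketch, not part of this diff):

    import torch
    import torch.nn.functional as F

    def bias_act_lrelu_ref(x, b, alpha=0.2, gain=2 ** 0.5, clamp=None):
        # act(x + b) * gain, then optional symmetric clamping
        b = b.reshape([1, -1] + [1] * (x.ndim - 2))   # broadcast bias over dim=1
        y = F.leaky_relu(x + b, alpha) * gain
        return y.clamp(-clamp, clamp) if clamp is not None else y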
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.h b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..8a2c3a861ce21e6470de2aaaec24bb2713259c63
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.h
@@ -0,0 +1,42 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.py b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d4e2aedd6deddafe1fd9c80a009f6d1ec99746b
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.py
@@ -0,0 +1,211 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import torch
+from .. import dnnlib
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='bias_act_plugin',
+ sources=['bias_act.cpp', 'bias_act.cu'],
+ headers=['bias_act.h'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
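+# Illustrative usage (a minimal sketch, not part of the original op; assumes a CUDA
+# device so the fused kernel path is taken, otherwise the reference path runs):
+#
+#   x = torch.randn(8, 512, 4, 4, device='cuda')
+#   b = torch.randn(512, device='cuda')
+#   y = bias_act(x, b, dim=1, act='lrelu')                   # fused bias + leaky ReLU
+#   y_ref = bias_act(x, b, dim=1, act='lrelu', impl='ref')   # pure-PyTorch reference
+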
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cpp b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..05ba9baa08b8bf0aac3643219306b871f9dee650
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cpp
@@ -0,0 +1,57 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+// CUDA forward declarations
+
+namespace at {namespace native {
+std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners);
+std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners);
+}}
+
+std::vector<torch::Tensor> grid_sample2d_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ return at::native::grid_sample2d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
+ grad_output, input, grid, padding_mode, align_corners);
+}
+
+std::vector<torch::Tensor> grid_sample3d_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ return at::native::grid_sample3d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
+ grad_output, input, grid, padding_mode, align_corners);
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("grad2_2d", &grid_sample2d_grad2, "grid_sample2d second derivative");
+ m.def("grad2_3d", &grid_sample3d_grad2, "grid_sample3d second derivative");
+}
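+
+// Python-side usage of the bound functions (illustrative; mirrors the calls made in
+// grid_sample.py once the extension is built):
+//   grad_grad_output, grad_input, grad_grid = gridsample_grad2.grad2_2d(
+//       grad2_grad_input, grad2_grad_grid, grad_output, input, grid,
+//       padding_mode, align_corners)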
+
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cu b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cu
new file mode 100644
index 0000000000000000000000000000000000000000..02b5a33918eec1c6baac348b765bb2e3467ce319
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cu
@@ -0,0 +1,668 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/native/GridSampler.h>
+#include <ATen/native/cuda/GridSampler.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/macros/Macros.h>
+#include <c10/cuda/CUDAException.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+#include <iostream>
+
+namespace at { namespace native {
+namespace {
+
+using namespace at::cuda::detail;
+
+using at::native::detail::GridSamplerInterpolation;
+using at::native::detail::GridSamplerPadding;
+
+template <typename scalar_t, typename index_t>
+ C10_LAUNCH_BOUNDS_1(256)
+ __global__ void grid_sampler_2d_grad2_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad2_grad_input,
+ TensorInfo<scalar_t, index_t> grad2_grad_grid,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_grad_output,
+ TensorInfo<scalar_t, index_t> grad_input,
+ TensorInfo<scalar_t, index_t> grad_grid,
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_H = input.sizes[2];
+ index_t inp_W = input.sizes[3];
+
+ index_t out_H = grid.sizes[1];
+ index_t out_W = grid.sizes[2];
+
+ index_t g2inp_sN = grad2_grad_input.strides[0];
+ index_t g2inp_sC = grad2_grad_input.strides[1];
+ index_t g2inp_sH = grad2_grad_input.strides[2];
+ index_t g2inp_sW = grad2_grad_input.strides[3];
+
+ index_t g2grid_sN = grad2_grad_grid.strides[0];
+ index_t g2grid_sH = grad2_grad_grid.strides[1];
+ index_t g2grid_sW = grad2_grad_grid.strides[2];
+ index_t g2grid_sCoor = grad2_grad_grid.strides[3];
+
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sH = grad_output.strides[2];
+ index_t gOut_sW = grad_output.strides[3];
+
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sH = input.strides[2];
+ index_t inp_sW = input.strides[3];
+
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sH = grid.strides[1];
+ index_t grid_sW = grid.strides[2];
+ index_t grid_sCoor = grid.strides[3];
+
+ index_t gInp_sN = grad_input.strides[0];
+ index_t gInp_sC = grad_input.strides[1];
+ index_t gInp_sH = grad_input.strides[2];
+ index_t gInp_sW = grad_input.strides[3];
+
+ index_t gGrid_sW = grad_grid.strides[2];
+
+ index_t ggOut_sN = grad_grad_output.strides[0];
+ index_t ggOut_sC = grad_grad_output.strides[1];
+ index_t ggOut_sH = grad_grad_output.strides[2];
+ index_t ggOut_sW = grad_grad_output.strides[3];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t n = index / (out_H * out_W);
+
+ /* Grid related stuff */
+ index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y co-ordinates from grid
+ scalar_t x = grid.data[grid_offset];
+ scalar_t y = grid.data[grid_offset + grid_sCoor];
+
+ // multipliers for gradients on ix and iy
+ scalar_t gix_mult, giy_mult;
+ scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
+ scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);
+
+ // get NE, NW, SE, SW pixel values from (x, y)
+ index_t ix_nw = static_cast<index_t>(::floor(ix));
+ index_t iy_nw = static_cast<index_t>(::floor(iy));
+ index_t ix_ne = ix_nw + 1;
+ index_t iy_ne = iy_nw;
+ index_t ix_sw = ix_nw;
+ index_t iy_sw = iy_nw + 1;
+ index_t ix_se = ix_nw + 1;
+ index_t iy_se = iy_nw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t nw = (ix_se - ix) * (iy_se - iy);
+ scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
+ scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
+ scalar_t se = (ix - ix_nw) * (iy - iy_nw);
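+
+ // Added note: the bilinear output at this location is
+ // out = nw*nw_val + ne*ne_val + sw*sw_val + se*se_val, with the weights above linear
+ // in (ix, iy). The second-order terms below follow from this: the derivative w.r.t.
+ // each corner value is its weight, the derivative of a weight w.r.t. ix/iy is a
+ // signed 1-D edge length (the *_tmp terms), and dxy is the mixed second derivative.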
+
+ /* grad2_grad_input related init */
+ scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
+
+ /* grad2_grad_grid related init */
+ grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW;
+ scalar_t dx = grad2_grad_grid.data[grid_offset];
+ scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
+
+ dx = dx * gix_mult;
+ dy = dy * giy_mult;
+
+ /* grad_output related init */
+ scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+
+ /* input related init */
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+
+ /* grad_grad_output related init */
+ scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW;
+
+ /* grad_input related init */
+ index_t NC_offset = n * gInp_sN;
+
+ /* grad_grid related init */
+ scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
+
+ scalar_t nw_val, ne_val, sw_val, se_val;
+ scalar_t g2_nw_val, g2_ne_val, g2_sw_val, g2_se_val;
+
+ scalar_t zero = static_cast<scalar_t>(0);
+ for (index_t c = 0; c < C;
+ ++c,
+ g2_inp_ptr_NC += g2inp_sC,
+ inp_ptr_NC += inp_sC,
+ NC_offset += gInp_sC,
+ gOut_ptr_NCHW += gOut_sC,
+ ggOut_ptr_NCHW += ggOut_sC) {
+
+ nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]: zero;
+ ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]: zero;
+ sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]: zero;
+ se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]: zero;
+
+ g2_nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW]: zero;
+ g2_ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW]: zero;
+ g2_sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW]: zero;
+ g2_se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW]: zero;
+
+ // Computing gradient wrt to grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
+ // grad2_grad_input * x * y
+ *ggOut_ptr_NCHW = static_cast<scalar_t>(0);
+ *ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se;
+
+ scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix);
+ scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw);
+ scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix);
+ scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw);
+
+
+ // grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
+ *ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val;
+
+ // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x
+ scalar_t gOut = *gOut_ptr_NCHW;
+ //scalar_t val;
+ //val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix));
+ safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw));
+ safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix));
+ safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw));
+ safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span);
+
+ scalar_t dxy = nw_val - ne_val - sw_val + se_val;
+ // Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut
+ gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy)
+ -g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw));
+ gix += gOut * dy * dxy;
+
+ // Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut
+ giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw)
+ +g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw));
+ giy += gOut * dx * dxy;
+ }
+
+ gGrid_ptr_NHW[0] = gix * gix_mult;
+ gGrid_ptr_NHW[1] = giy * giy_mult;
+ }
+}
+
+template <typename scalar_t, typename index_t>
+ C10_LAUNCH_BOUNDS_1(256)
+ __global__ void grid_sampler_3d_grad2_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad2_grad_input,
+ TensorInfo<scalar_t, index_t> grad2_grad_grid,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_grad_output,
+ TensorInfo<scalar_t, index_t> grad_input,
+ TensorInfo<scalar_t, index_t> grad_grid,
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_D = input.sizes[2];
+ index_t inp_H = input.sizes[3];
+ index_t inp_W = input.sizes[4];
+
+ index_t out_D = grid.sizes[1];
+ index_t out_H = grid.sizes[2];
+ index_t out_W = grid.sizes[3];
+
+ index_t g2inp_sN = grad2_grad_input.strides[0];
+ index_t g2inp_sC = grad2_grad_input.strides[1];
+ index_t g2inp_sD = grad2_grad_input.strides[2];
+ index_t g2inp_sH = grad2_grad_input.strides[3];
+ index_t g2inp_sW = grad2_grad_input.strides[4];
+
+ index_t g2grid_sN = grad2_grad_grid.strides[0];
+ index_t g2grid_sD = grad2_grad_grid.strides[1];
+ index_t g2grid_sH = grad2_grad_grid.strides[2];
+ index_t g2grid_sW = grad2_grad_grid.strides[3];
+ index_t g2grid_sCoor = grad2_grad_grid.strides[4];
+
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sD = grad_output.strides[2];
+ index_t gOut_sH = grad_output.strides[3];
+ index_t gOut_sW = grad_output.strides[4];
+
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sD = input.strides[2];
+ index_t inp_sH = input.strides[3];
+ index_t inp_sW = input.strides[4];
+
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sD = grid.strides[1];
+ index_t grid_sH = grid.strides[2];
+ index_t grid_sW = grid.strides[3];
+ index_t grid_sCoor = grid.strides[4];
+
+ index_t gInp_sN = grad_input.strides[0];
+ index_t gInp_sC = grad_input.strides[1];
+ index_t gInp_sD = grad_input.strides[2];
+ index_t gInp_sH = grad_input.strides[3];
+ index_t gInp_sW = grad_input.strides[4];
+
+ index_t gGrid_sW = grad_grid.strides[3];
+
+ index_t ggOut_sN = grad_grad_output.strides[0];
+ index_t ggOut_sC = grad_grad_output.strides[1];
+ index_t ggOut_sD = grad_grad_output.strides[2];
+ index_t ggOut_sH = grad_grad_output.strides[3];
+ index_t ggOut_sW = grad_grad_output.strides[4];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t d = (index / (out_H * out_W)) % out_D;
+ const index_t n = index / (out_D * out_H * out_W);
+
+ /* Grid related stuff */
+ index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y co-ordinates from grid
+ scalar_t ix = grid.data[grid_offset];
+ scalar_t iy = grid.data[grid_offset + grid_sCoor];
+ scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
+
+ // multipliers for gradients on ix and iy
+ scalar_t gix_mult, giy_mult, giz_mult;
+ ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+ iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+ iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
+
+ // get NE, NW, SE, SW pixel values from (x, y)
+ index_t ix_tnw = static_cast<index_t>(::floor(ix));
+ index_t iy_tnw = static_cast<index_t>(::floor(iy));
+ index_t iz_tnw = static_cast<index_t>(::floor(iz));
+
+ index_t ix_tne = ix_tnw + 1;
+ index_t iy_tne = iy_tnw;
+ index_t iz_tne = iz_tnw;
+
+ index_t ix_tsw = ix_tnw;
+ index_t iy_tsw = iy_tnw + 1;
+ index_t iz_tsw = iz_tnw;
+
+ index_t ix_tse = ix_tnw + 1;
+ index_t iy_tse = iy_tnw + 1;
+ index_t iz_tse = iz_tnw;
+
+ index_t ix_bnw = ix_tnw;
+ index_t iy_bnw = iy_tnw;
+ index_t iz_bnw = iz_tnw + 1;
+
+ index_t ix_bne = ix_tnw + 1;
+ index_t iy_bne = iy_tnw;
+ index_t iz_bne = iz_tnw + 1;
+
+ index_t ix_bsw = ix_tnw;
+ index_t iy_bsw = iy_tnw + 1;
+ index_t iz_bsw = iz_tnw + 1;
+
+ index_t ix_bse = ix_tnw + 1;
+ index_t iy_bse = iy_tnw + 1;
+ index_t iz_bse = iz_tnw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+ scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+ scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+ scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+ scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+ scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+ scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+ scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
+
+ /* grad2_grad_input related init */
+ scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
+
+ /* grad2_grad_grid related init */
+ grid_offset = n * g2grid_sN + d * g2grid_sD + h * g2grid_sH + w * g2grid_sW;
+ scalar_t dx = grad2_grad_grid.data[grid_offset];
+ scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
+ scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor];
+
+ dx = dx * gix_mult;
+ dy = dy * giy_mult;
+ dz = dz * giz_mult;
+
+ /* grad_output related init */
+ scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+
+ /* input related init */
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+
+ /* grad_grad_output related init */
+ scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW;
+
+ /* grad_input related init */
+ index_t NC_offset = n * gInp_sN;
+
+ /* grad_grid related init */
+ scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
+
+ scalar_t tnw_val, tne_val, tsw_val, tse_val, bnw_val, bne_val, bsw_val, bse_val;
+ scalar_t g2_tnw_val, g2_tne_val, g2_tsw_val, g2_tse_val, g2_bnw_val, g2_bne_val, g2_bsw_val, g2_bse_val;
+
+ scalar_t zero = static_cast<scalar_t>(0);
+ for (index_t c = 0; c < C;
+ ++c,
+ g2_inp_ptr_NC += g2inp_sC,
+ inp_ptr_NC += inp_sC,
+ NC_offset += gInp_sC,
+ gOut_ptr_NCDHW += gOut_sC,
+ ggOut_ptr_NCDHW += ggOut_sC) {
+
+ if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+ tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
+ g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW];
+ } else {
+ tnw_val = zero;
+ g2_tnw_val = zero;
+ }
+ if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+ tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
+ g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW];
+ } else {
+ tne_val = zero;
+ g2_tne_val = zero;
+ }
+ if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+ tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
+ g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW];
+ } else {
+ tsw_val = zero;
+ g2_tsw_val = zero;
+ }
+ if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+ tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
+ g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW];
+ } else {
+ tse_val = zero;
+ g2_tse_val = zero;
+ }
+
+ if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+ bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
+ g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW];
+ } else {
+ bnw_val = zero;
+ g2_bnw_val = zero;
+ }
+ if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+ bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
+ g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW];
+ } else {
+ bne_val = zero;
+ g2_bne_val = zero;
+ }
+ if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+ bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
+ g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW];
+ } else {
+ bsw_val = zero;
+ g2_bsw_val = zero;
+ }
+ if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+ bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
+ g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + iy_bse * g2inp_sH + ix_bse * g2inp_sW];
+ } else {
+ bse_val = zero;
+ g2_bse_val = zero;
+ }
+
+ // Computing gradient wrt to grad_output =
+ // grad2_grad_input * x * y * z
+ *ggOut_ptr_NCDHW = static_cast<scalar_t>(0);
+ *ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse
+ +g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse;
+
+ // +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y)
+ scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy));
+ scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy));
+ scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne));
+ scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw));
+ scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy));
+ scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy));
+ scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne));
+ scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw));
+
+ *ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + tsw_val * tsw_tmp + tse_val * tse_tmp
+ +bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp;
+
+ // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z +
+ // grad2_grad_grid_z * grad_output * y * z
+ scalar_t gOut = *gOut_ptr_NCDHW;
+
+ safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+
+ //Computing gradient wrt grid
+ scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz)
+ -tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz)
+ +bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw)
+ -bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw));
+
+ scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy)
+ +tsw_val * (iy - iy_bne) - tse_val * (iy - iy_bnw)
+ -bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy)
+ -bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw));
+
+ scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw)
+ -tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw)
+ -bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw)
+ +bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw));
+
+
+ // Computing gradient wrt grid_x =
+ // grad2_grad_input * z * y * gOut
+ gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz)
+ -g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz)
+ -g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw)
+ -g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw));
+
+ //+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut
+ gix += gOut * (dz * dxz + dy * dxy);
+
+ // Computing gradient wrt grid_y =
+ // grad2_grad_input * x * z * gOut
+ giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz)
+ +g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz)
+ -g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw)
+ +g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw));
+ //+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut
+ giy += gOut * (dx * dxy + dz * dyz);
+
+ // Computing gradient wrt grid_z =
+ // grad2_grad_input * x * y * gOut
+ giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy)
+ -g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw)
+ +g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy)
+ +g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw));
+ //+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut
+ giz += gOut * (dx * dxz + dy * dyz);
+ }
+
+ gGrid_ptr_NDHW[0] = gix * gix_mult;
+ gGrid_ptr_NDHW[1] = giy * giy_mult;
+ gGrid_ptr_NDHW[2] = giz * giz_mult;
+ }
+}}
+
+
+std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ const auto batch_size = input.size(0);
+ const auto C = input.size(1);
+ const auto H_IN = input.size(2);
+ const auto W_IN = input.size(3);
+
+ const auto H_OUT = grid.size(1);
+ const auto W_OUT = grid.size(2);
+
+ torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+ int64_t count = batch_size * H_OUT * W_OUT;
+
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_grad2_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(grad_output)) {
+ grid_sampler_2d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(grad2_grad_input),
+ getTensorInfo<scalar_t, int>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int>(grad_output),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(grad_grad_output),
+ getTensorInfo<scalar_t, int>(grad_input),
+ getTensorInfo<scalar_t, int>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ static_cast<int>(grad_input.numel()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_2d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int64_t>(grad_output),
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(grad_grad_output),
+ getTensorInfo<scalar_t, int64_t>(grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ grad_input.numel());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+
+ return {grad_grad_output, grad_input, grad_grid};
+}
+
+std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ const auto batch_size = input.size(0);
+ const auto C = input.size(1);
+ const auto D_IN = input.size(2);
+ const auto H_IN = input.size(3);
+ const auto W_IN = input.size(4);
+
+ const auto D_OUT = grid.size(1);
+ const auto H_OUT = grid.size(2);
+ const auto W_OUT = grid.size(3);
+
+ torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+ int64_t count = batch_size * D_OUT * H_OUT * W_OUT;
+
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_grad2_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(grad_output)) {
+ grid_sampler_3d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(grad2_grad_input),
+ getTensorInfo<scalar_t, int>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int>(grad_output),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(grad_grad_output),
+ getTensorInfo<scalar_t, int>(grad_input),
+ getTensorInfo<scalar_t, int>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ static_cast<int>(grad_input.numel()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_3d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int64_t>(grad_output),
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(grad_grad_output),
+ getTensorInfo<scalar_t, int64_t>(grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ grad_input.numel());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+
+ return {grad_grad_output, grad_input, grad_grid};
+}
+
+}}
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.py b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..8195ca18e17643cd286348ffabf274330330c5d0
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.py
@@ -0,0 +1,145 @@
+from torch.utils.cpp_extension import load
+import torch
+from pkg_resources import parse_version
+
+from .. import custom_ops
+import os
+
+# gridsample_grad2 = load(name='gridsample_grad2', sources=['third_party/ops/gridsample_cuda.cpp', 'third_party/ops/gridsample_cuda.cu'], verbose=True)
+
+# The extension is built lazily via custom_ops.get_plugin() in _init() below; an eager
+# torch.utils.cpp_extension.load() build is kept here, commented out, as an alternative:
+# gridsample_grad2 = load(name='gridsample_grad2', sources=[os.path.join(os.path.dirname(__file__), f) for f in ['grid_sample.cpp', 'grid_sample.cu']], verbose=True)
+
+gridsample_grad2 = None
+
+def _init():
+ global gridsample_grad2
+ if gridsample_grad2 is None:
+ gridsample_grad2 = custom_ops.get_plugin(
+ module_name='gridsample_grad2',
+ sources=['gridsample_cuda.cpp', 'gridsample_cuda.cu'],
+ headers=None,
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+_init()
+
+def grid_sample_2d(input, grid, padding_mode='zeros', align_corners=True):
+ assert padding_mode in ['zeros', 'border']
+ return _GridSample2dForward.apply(input, grid, padding_mode, align_corners)
+
+
+def grid_sample_3d(input, grid, padding_mode='zeros', align_corners=True):
+ assert padding_mode in ['zeros', 'border']
+ return _GridSample3dForward.apply(input, grid, padding_mode, align_corners)
+
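+# Illustrative usage (a sketch, assuming CUDA tensors since the double-backward path
+# calls the compiled extension; shapes follow torch.nn.functional.grid_sample):
+#
+#   feat = torch.randn(2, 32, 64, 64, device='cuda', requires_grad=True)
+#   grid = (torch.rand(2, 48, 48, 2, device='cuda') * 2 - 1).requires_grad_(True)
+#   out = grid_sample_2d(feat, grid, padding_mode='zeros', align_corners=True)
+#   # First-order gradients remain differentiable, e.g. for gradient-penalty terms:
+#   g, = torch.autograd.grad(out.sum(), grid, create_graph=True)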
+
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a')
+
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid, padding_mode=0, align_corners=True):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ assert input.shape[0] == grid.shape[0]
+ assert grid.shape[3] == 2
+
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
+ padding_mode=padding_mode, align_corners=align_corners)
+ ctx.save_for_backward(input, grid)
+ ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
+ ctx.align_corners = align_corners
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
+ return grad_input, grad_grid, None, None
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')[0]
+ if _use_pytorch_1_11_api:
+ output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
+ # breakpoint()
+ grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
+ else:
+ grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
+
+ ctx.save_for_backward(grad_output, input, grid)
+ ctx.padding_mode = padding_mode
+ ctx.align_corners = align_corners
+
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ grad_output, input, grid = ctx.saved_tensors
+ assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
+ out = gridsample_grad2.grad2_2d(grad2_grad_input, grad2_grad_grid, grad_output,
+ input, grid, ctx.padding_mode, ctx.align_corners)
+
+ grad_grad_output = out[0]
+ grad_input = out[1]
+ grad_grid = out[2]
+
+ return grad_grad_output, grad_input, grad_grid, None, None
+
+
+class _GridSample3dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid, padding_mode=0, align_corners=True):
+ assert input.ndim == 5
+ assert grid.ndim == 5
+ assert input.shape[0] == grid.shape[0]
+ assert grid.shape[4] == 3
+
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
+ padding_mode=padding_mode, align_corners=align_corners)
+ ctx.save_for_backward(input, grid)
+ ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
+ ctx.align_corners = align_corners
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample3dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
+ return grad_input, grad_grid, None, None
+
+
+
+class _GridSample3dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
+ op = torch._C._jit_get_operation('aten::grid_sampler_3d_backward')[0]
+ if _use_pytorch_1_11_api:
+ output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
+ grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
+ else:
+ grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
+
+ ctx.save_for_backward(grad_output, input, grid)
+ ctx.padding_mode = padding_mode
+ ctx.align_corners = align_corners
+
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ grad_output, input, grid = ctx.saved_tensors
+ assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
+ out = gridsample_grad2.grad2_3d(grad2_grad_input, grad2_grad_grid, grad_output,
+ input, grid, ctx.padding_mode, ctx.align_corners)
+
+ grad_grad_output = out[0]
+ grad_input = out[1]
+ grad_grid = out[2]
+
+ return grad_grad_output, grad_input, grad_grid, None, None
+
+
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample_gradfix.py b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5d766c16fa3dca90f60abba16d634dfb0510300
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample_gradfix.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = True # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
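+# Minimal illustrative example (mirrors torch.nn.functional.grid_sample with
+# mode='bilinear', padding_mode='zeros', align_corners=False):
+#
+#   img = torch.randn(1, 3, 32, 32, requires_grad=True)
+#   grid = torch.rand(1, 16, 16, 2) * 2 - 1
+#   out = grid_sample(img, grid)   # higher-order gradients w.r.t. `img` stay available
+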
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ return enabled
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None # grad2_grad_input #
+ grad2_grid = None # grad2_grad_grid #
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cpp b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..05ba9baa08b8bf0aac3643219306b871f9dee650
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cpp
@@ -0,0 +1,57 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+// CUDA forward declarations
+
+namespace at {namespace native {
+std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners);
+std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners);
+}}
+
+std::vector<torch::Tensor> grid_sample2d_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ return at::native::grid_sample2d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
+ grad_output, input, grid, padding_mode, align_corners);
+}
+
+std::vector<torch::Tensor> grid_sample3d_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ return at::native::grid_sample3d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
+ grad_output, input, grid, padding_mode, align_corners);
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("grad2_2d", &grid_sample2d_grad2, "grid_sample2d second derivative");
+ m.def("grad2_3d", &grid_sample3d_grad2, "grid_sample3d second derivative");
+}
+
diff --git a/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cu b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..02b5a33918eec1c6baac348b765bb2e3467ce319
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cu
@@ -0,0 +1,668 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/native/GridSampler.h>
+#include <ATen/native/cuda/GridSampler.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/macros/Macros.h>
+#include <c10/cuda/CUDAException.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+#include <iostream>
+
+namespace at { namespace native {
+namespace {
+
+using namespace at::cuda::detail;
+
+using at::native::detail::GridSamplerInterpolation;
+using at::native::detail::GridSamplerPadding;
+
+template <typename scalar_t, typename index_t>
+ C10_LAUNCH_BOUNDS_1(256)
+ __global__ void grid_sampler_2d_grad2_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad2_grad_input,
+ TensorInfo<scalar_t, index_t> grad2_grad_grid,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_grad_output,
+ TensorInfo<scalar_t, index_t> grad_input,
+ TensorInfo<scalar_t, index_t> grad_grid,
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_H = input.sizes[2];
+ index_t inp_W = input.sizes[3];
+
+ index_t out_H = grid.sizes[1];
+ index_t out_W = grid.sizes[2];
+
+ index_t g2inp_sN = grad2_grad_input.strides[0];
+ index_t g2inp_sC = grad2_grad_input.strides[1];
+ index_t g2inp_sH = grad2_grad_input.strides[2];
+ index_t g2inp_sW = grad2_grad_input.strides[3];
+
+ index_t g2grid_sN = grad2_grad_grid.strides[0];
+ index_t g2grid_sH = grad2_grad_grid.strides[1];
+ index_t g2grid_sW = grad2_grad_grid.strides[2];
+ index_t g2grid_sCoor = grad2_grad_grid.strides[3];
+
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sH = grad_output.strides[2];
+ index_t gOut_sW = grad_output.strides[3];
+
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sH = input.strides[2];
+ index_t inp_sW = input.strides[3];
+
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sH = grid.strides[1];
+ index_t grid_sW = grid.strides[2];
+ index_t grid_sCoor = grid.strides[3];
+
+ index_t gInp_sN = grad_input.strides[0];
+ index_t gInp_sC = grad_input.strides[1];
+ index_t gInp_sH = grad_input.strides[2];
+ index_t gInp_sW = grad_input.strides[3];
+
+ index_t gGrid_sW = grad_grid.strides[2];
+
+ index_t ggOut_sN = grad_grad_output.strides[0];
+ index_t ggOut_sC = grad_grad_output.strides[1];
+ index_t ggOut_sH = grad_grad_output.strides[2];
+ index_t ggOut_sW = grad_grad_output.strides[3];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t n = index / (out_H * out_W);
+
+ /* Grid related stuff */
+ index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y co-ordinates from grid
+ scalar_t x = grid.data[grid_offset];
+ scalar_t y = grid.data[grid_offset + grid_sCoor];
+
+ // multipliers for gradients on ix and iy
+ scalar_t gix_mult, giy_mult;
+ scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
+ scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);
+
+ // get NE, NW, SE, SW pixel values from (x, y)
+ index_t ix_nw = static_cast<index_t>(::floor(ix));
+ index_t iy_nw = static_cast<index_t>(::floor(iy));
+ index_t ix_ne = ix_nw + 1;
+ index_t iy_ne = iy_nw;
+ index_t ix_sw = ix_nw;
+ index_t iy_sw = iy_nw + 1;
+ index_t ix_se = ix_nw + 1;
+ index_t iy_se = iy_nw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t nw = (ix_se - ix) * (iy_se - iy);
+ scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
+ scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
+ scalar_t se = (ix - ix_nw) * (iy - iy_nw);
+
+ /* grad2_grad_input related init */
+ scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
+
+ /* grad2_grad_grid related init */
+ grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW;
+ scalar_t dx = grad2_grad_grid.data[grid_offset];
+ scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
+
+ dx = dx * gix_mult;
+ dy = dy * giy_mult;
+
+ /* grad_output related init */
+ scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+
+ /* input related init */
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+
+ /* grad_grad_output related init */
+ scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW;
+
+ /* grad_input related init */
+ index_t NC_offset = n * gInp_sN;
+
+ /* grad_grid related init */
+ scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
+
+ scalar_t nw_val, ne_val, sw_val, se_val;
+ scalar_t g2_nw_val, g2_ne_val, g2_sw_val, g2_se_val;
+
+ scalar_t zero = static_cast<scalar_t>(0);
+ for (index_t c = 0; c < C;
+ ++c,
+ g2_inp_ptr_NC += g2inp_sC,
+ inp_ptr_NC += inp_sC,
+ NC_offset += gInp_sC,
+ gOut_ptr_NCHW += gOut_sC,
+ ggOut_ptr_NCHW += ggOut_sC) {
+
+ nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]: zero;
+ ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]: zero;
+ sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]: zero;
+ se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]: zero;
+
+ g2_nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW]: zero;
+ g2_ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW]: zero;
+ g2_sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW]: zero;
+ g2_se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW]: zero;
+
+ // Computing gradient wrt to grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
+ // grad2_grad_input * x * y
+ *ggOut_ptr_NCHW = static_cast<scalar_t>(0);
+ *ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se;
+
+ scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix);
+ scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw);
+ scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix);
+ scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw);
+
+
+ // grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
+ *ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val;
+
+ // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x
+ scalar_t gOut = *gOut_ptr_NCHW;
+ //scalar_t val;
+ //val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix));
+ safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw));
+ safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix));
+ safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw));
+ safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span);
+
+ scalar_t dxy = nw_val - ne_val - sw_val + se_val;
+ // Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut
+ gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy)
+ -g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw));
+ gix += gOut * dy * dxy;
+
+ // Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut
+ giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw)
+ +g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw));
+ giy += gOut * dx * dxy;
+ }
+
+ gGrid_ptr_NHW[0] = gix * gix_mult;
+ gGrid_ptr_NHW[1] = giy * giy_mult;
+ }
+}
+
+template <typename scalar_t, typename index_t>
+ C10_LAUNCH_BOUNDS_1(256)
+ __global__ void grid_sampler_3d_grad2_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad2_grad_input,
+ TensorInfo<scalar_t, index_t> grad2_grad_grid,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_grad_output,
+ TensorInfo<scalar_t, index_t> grad_input,
+ TensorInfo<scalar_t, index_t> grad_grid,
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_D = input.sizes[2];
+ index_t inp_H = input.sizes[3];
+ index_t inp_W = input.sizes[4];
+
+ index_t out_D = grid.sizes[1];
+ index_t out_H = grid.sizes[2];
+ index_t out_W = grid.sizes[3];
+
+ index_t g2inp_sN = grad2_grad_input.strides[0];
+ index_t g2inp_sC = grad2_grad_input.strides[1];
+ index_t g2inp_sD = grad2_grad_input.strides[2];
+ index_t g2inp_sH = grad2_grad_input.strides[3];
+ index_t g2inp_sW = grad2_grad_input.strides[4];
+
+ index_t g2grid_sN = grad2_grad_grid.strides[0];
+ index_t g2grid_sD = grad2_grad_grid.strides[1];
+ index_t g2grid_sH = grad2_grad_grid.strides[2];
+ index_t g2grid_sW = grad2_grad_grid.strides[3];
+ index_t g2grid_sCoor = grad2_grad_grid.strides[4];
+
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sD = grad_output.strides[2];
+ index_t gOut_sH = grad_output.strides[3];
+ index_t gOut_sW = grad_output.strides[4];
+
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sD = input.strides[2];
+ index_t inp_sH = input.strides[3];
+ index_t inp_sW = input.strides[4];
+
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sD = grid.strides[1];
+ index_t grid_sH = grid.strides[2];
+ index_t grid_sW = grid.strides[3];
+ index_t grid_sCoor = grid.strides[4];
+
+ index_t gInp_sN = grad_input.strides[0];
+ index_t gInp_sC = grad_input.strides[1];
+ index_t gInp_sD = grad_input.strides[2];
+ index_t gInp_sH = grad_input.strides[3];
+ index_t gInp_sW = grad_input.strides[4];
+
+ index_t gGrid_sW = grad_grid.strides[3];
+
+ index_t ggOut_sN = grad_grad_output.strides[0];
+ index_t ggOut_sC = grad_grad_output.strides[1];
+ index_t ggOut_sD = grad_grad_output.strides[2];
+ index_t ggOut_sH = grad_grad_output.strides[3];
+ index_t ggOut_sW = grad_grad_output.strides[4];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t d = (index / (out_H * out_W)) % out_D;
+ const index_t n = index / (out_D * out_H * out_W);
+
+ /* Grid related stuff */
+ index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y co-ordinates from grid
+ scalar_t ix = grid.data[grid_offset];
+ scalar_t iy = grid.data[grid_offset + grid_sCoor];
+ scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
+
+ // multipliers for gradients on ix and iy
+ scalar_t gix_mult, giy_mult, giz_mult;
+ ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+ iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+ iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
+
+ // get NE, NW, SE, SW pixel values from (x, y)
+ index_t ix_tnw = static_cast<index_t>(::floor(ix));
+ index_t iy_tnw = static_cast<index_t>(::floor(iy));
+ index_t iz_tnw = static_cast<index_t>(::floor(iz));
+
+ index_t ix_tne = ix_tnw + 1;
+ index_t iy_tne = iy_tnw;
+ index_t iz_tne = iz_tnw;
+
+ index_t ix_tsw = ix_tnw;
+ index_t iy_tsw = iy_tnw + 1;
+ index_t iz_tsw = iz_tnw;
+
+ index_t ix_tse = ix_tnw + 1;
+ index_t iy_tse = iy_tnw + 1;
+ index_t iz_tse = iz_tnw;
+
+ index_t ix_bnw = ix_tnw;
+ index_t iy_bnw = iy_tnw;
+ index_t iz_bnw = iz_tnw + 1;
+
+ index_t ix_bne = ix_tnw + 1;
+ index_t iy_bne = iy_tnw;
+ index_t iz_bne = iz_tnw + 1;
+
+ index_t ix_bsw = ix_tnw;
+ index_t iy_bsw = iy_tnw + 1;
+ index_t iz_bsw = iz_tnw + 1;
+
+ index_t ix_bse = ix_tnw + 1;
+ index_t iy_bse = iy_tnw + 1;
+ index_t iz_bse = iz_tnw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+ scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+ scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+ scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+ scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+ scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+ scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+ scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
+
+ /* grad2_grad_input related init */
+ scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
+
+ /* grad2_grad_grid related init */
+ grid_offset = n * g2grid_sN + d * g2grid_sD + h * g2grid_sH + w * g2grid_sW;
+ scalar_t dx = grad2_grad_grid.data[grid_offset];
+ scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
+ scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor];
+
+ dx = dx * gix_mult;
+ dy = dy * giy_mult;
+ dz = dz * giz_mult;
+
+ /* grad_output related init */
+ scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+
+ /* input related init */
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+
+ /* grad_grad_output related init */
+ scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW;
+
+ /* grad_input related init */
+ index_t NC_offset = n * gInp_sN;
+
+ /* grad_grid related init */
+ scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
+
+ scalar_t tnw_val, tne_val, tsw_val, tse_val, bnw_val, bne_val, bsw_val, bse_val;
+ scalar_t g2_tnw_val, g2_tne_val, g2_tsw_val, g2_tse_val, g2_bnw_val, g2_bne_val, g2_bsw_val, g2_bse_val;
+
+ scalar_t zero = static_cast<scalar_t>(0);
+ for (index_t c = 0; c < C;
+ ++c,
+ g2_inp_ptr_NC += g2inp_sC,
+ inp_ptr_NC += inp_sC,
+ NC_offset += gInp_sC,
+ gOut_ptr_NCDHW += gOut_sC,
+ ggOut_ptr_NCDHW += ggOut_sC) {
+
+ if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+ tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
+ g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW];
+ } else {
+ tnw_val = zero;
+ g2_tnw_val = zero;
+ }
+ if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+ tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
+ g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW];
+ } else {
+ tne_val = zero;
+ g2_tne_val = zero;
+ }
+ if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+ tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
+ g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW];
+ } else {
+ tsw_val = zero;
+ g2_tsw_val = zero;
+ }
+ if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+ tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
+ g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW];
+ } else {
+ tse_val = zero;
+ g2_tse_val = zero;
+ }
+
+ if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+ bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
+ g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW];
+ } else {
+ bnw_val = zero;
+ g2_bnw_val = zero;
+ }
+ if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+ bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
+ g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW];
+ } else {
+ bne_val = zero;
+ g2_bne_val = zero;
+ }
+ if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+ bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
+ g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW];
+ } else {
+ bsw_val = zero;
+ g2_bsw_val = zero;
+ }
+ if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+ bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
+ g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + iy_bse * g2inp_sH + ix_bse * g2inp_sW];
+ } else {
+ bse_val = zero;
+ g2_bse_val = zero;
+ }
+
+ // Computing gradient wrt grad_output =
+ // grad2_grad_input * x * y * z
+ *ggOut_ptr_NCDHW = static_cast<scalar_t>(0);
+ *ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse
+ +g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse;
+
+ // +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y)
+ scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy));
+ scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy));
+ scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne));
+ scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw));
+ scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy));
+ scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy));
+ scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne));
+ scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw));
+
+ *ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + tsw_val * tsw_tmp + tse_val * tse_tmp
+ +bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp;
+
+ // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z +
+ // grad2_grad_grid_z * grad_output * x * y
+ scalar_t gOut = *gOut_ptr_NCDHW;
+
+ safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+
+ //Computing gradient wrt grid
+ scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz)
+ -tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz)
+ +bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw)
+ -bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw));
+
+ scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy)
+ +tsw_val * (iy - iy_bne) - tse_val * (iy - iy_bnw)
+ -bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy)
+ -bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw));
+
+ scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw)
+ -tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw)
+ -bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw)
+ +bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw));
+
+
+ // Computing gradient wrt grid_x =
+ // grad2_grad_input * z * y * gOut
+ gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz)
+ -g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz)
+ -g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw)
+ -g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw));
+
+ //+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut
+ gix += gOut * (dz * dxz + dy * dxy);
+
+ // Computing gradient wrt grid_y =
+ // grad2_grad_input * x * z * gOut
+ giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz)
+ +g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz)
+ -g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw)
+ +g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw));
+ //+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut
+ giy += gOut * (dx * dxy + dz * dyz);
+
+ // Computing gradient wrt grid_z =
+ // grad2_grad_input * x * y * gOut
+ giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy)
+ -g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw)
+ +g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy)
+ +g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw));
+ //+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut
+ giz += gOut * (dx * dxz + dy * dyz);
+ }
+
+ gGrid_ptr_NDHW[0] = gix * gix_mult;
+ gGrid_ptr_NDHW[1] = giy * giy_mult;
+ gGrid_ptr_NDHW[2] = giz * giz_mult;
+ }
+}}
+
+
+std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ const auto batch_size = input.size(0);
+ const auto C = input.size(1);
+ const auto H_IN = input.size(2);
+ const auto W_IN = input.size(3);
+
+ const auto H_OUT = grid.size(1);
+ const auto W_OUT = grid.size(2);
+
+ torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+ int64_t count = batch_size * H_OUT * W_OUT;
+
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_grad2_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(grad_output)) {
+ grid_sampler_2d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(grad2_grad_input),
+ getTensorInfo<scalar_t, int>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int>(grad_output),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(grad_grad_output),
+ getTensorInfo<scalar_t, int>(grad_input),
+ getTensorInfo<scalar_t, int>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ static_cast<int>(grad_input.numel()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_2d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int64_t>(grad_output),
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(grad_grad_output),
+ getTensorInfo<scalar_t, int64_t>(grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ grad_input.numel());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+
+ return {grad_grad_output, grad_input, grad_grid};
+}
+
+std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
+ const torch::Tensor &grad2_grad_input,
+ const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &grad_output,
+ const torch::Tensor &input,
+ const torch::Tensor &grid,
+ bool padding_mode,
+ bool align_corners) {
+
+ const auto batch_size = input.size(0);
+ const auto C = input.size(1);
+ const auto D_IN = input.size(2);
+ const auto H_IN = input.size(3);
+ const auto W_IN = input.size(4);
+
+ const auto D_OUT = grid.size(1);
+ const auto H_OUT = grid.size(2);
+ const auto W_OUT = grid.size(3);
+
+ torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+ int64_t count = batch_size * D_OUT * H_OUT * W_OUT;
+
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_grad2_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(grad_output)) {
+ grid_sampler_3d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(grad2_grad_input),
+ getTensorInfo<scalar_t, int>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int>(grad_output),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(grad_grad_output),
+ getTensorInfo<scalar_t, int>(grad_input),
+ getTensorInfo<scalar_t, int>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ static_cast<int>(grad_input.numel()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_3d_grad2_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int64_t>(grad_output),
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(grad_grad_output),
+ getTensorInfo<scalar_t, int64_t>(grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ grad_input.numel());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+
+ return {grad_grad_output, grad_input, grad_grid};
+}
+
+}}
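Note: the two host functions above return {grad_grad_output, grad_input, grad_grid}, i.e. the double-backward quantities of bilinear grid sampling. Below is a minimal Python sketch of the computation these kernels fuse, written against torch.nn.functional.grid_sample; it assumes a PyTorch build where grid_sample's double backward is available natively, and does not assume any binding names since the Python-side wrapper is not part of this hunk.

    import torch
    import torch.nn.functional as F

    # hypothetical feature map and sampling grid
    feat = torch.randn(1, 16, 32, 32, requires_grad=True)
    grid = (torch.rand(1, 64, 1, 2) * 2 - 1).requires_grad_(True)

    sampled = F.grid_sample(feat, grid, mode='bilinear',
                            padding_mode='zeros', align_corners=False)

    # first-order gradient w.r.t. the grid, kept in the autograd graph
    g_grid, = torch.autograd.grad(sampled.sum(), grid, create_graph=True)

    # any loss on g_grid (e.g. an eikonal-style term on sampled SDF gradients)
    # triggers the double backward that grid_sample2d_cuda_grad2 computes in one
    # fused kernel: gradients w.r.t. grad_output, input and grid.
    loss = (g_grid.norm(dim=-1) - 1).pow(2).mean()
    loss.backward()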
diff --git a/svrm/ldm/modules/rendering_neus/third_party/pytorch_ssim/__init__.py b/svrm/ldm/modules/rendering_neus/third_party/pytorch_ssim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..14d38b1722a7242ecd5353063eecb38ccc3aee02
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/third_party/pytorch_ssim/__init__.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+ return window
+
+
+def _ssim(img1, img2, window, window_size, channel, use_padding, size_average=True):
+
+ if use_padding:
+ padding_size = window_size // 2
+ else:
+ padding_size = 0
+
+ mu1 = F.conv2d(img1, window, padding=padding_size, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=padding_size, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padding_size, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padding_size, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=padding_size, groups=channel) - mu1_mu2
+
+ C1 = 0.01 ** 2
+ C2 = 0.03 ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+class SSIM(torch.nn.Module):
+ def __init__(self, window_size=11, use_padding=True, size_average=True):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.use_padding = use_padding
+ self.channel = 1
+ self.window = create_window(window_size, self.channel)
+
+ def forward(self, img1, img2):
+ (_, channel, _, _) = img1.size()
+
+ if channel == self.channel and self.window.data.type() == img1.data.type():
+ window = self.window
+ else:
+ window = create_window(self.window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ self.window = window
+ self.channel = channel
+
+ return _ssim(img1, img2, window, self.window_size, channel, self.use_padding, self.size_average)
+
+
+def ssim(img1, img2, use_padding=True, window_size=11, size_average=True):
+ """SSIM only defined at intensity channel. For RGB or YUV or other image format, this function computes SSIm at each
+ channel and averge them.
+ :param img1: (B, C, H, W) float32 in [0, 1]
+ :param img2: (B, C, H, W) float32 in [0, 1]
+ :param use_padding: we use conv2d when we compute mean and var for each patch, this use_padding is for that conv2d.
+ :param window_size: patch size
+ :param size_average:
+ :return: a tensor that contains only one scalar.
+ """
+ (_, channel, _, _) = img1.size()
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, use_padding, size_average)
\ No newline at end of file
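A short usage sketch for the SSIM module added above (the import path is an assumption; it depends on how this third_party directory ends up on the Python path):

    import torch
    from svrm.ldm.modules.rendering_neus.third_party.pytorch_ssim import SSIM, ssim

    img1 = torch.rand(2, 3, 64, 64)                                   # values in [0, 1]
    img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1).requires_grad_(True)

    # functional form: a single scalar averaged over batch, channels and pixels
    print(ssim(img1, img2, window_size=11))

    # module form caches the Gaussian window between calls
    criterion = SSIM(window_size=11, size_average=True)
    loss = 1.0 - criterion(img1, img2)   # SSIM is a similarity, so 1 - SSIM works as a loss
    loss.backward()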
diff --git a/svrm/ldm/modules/rendering_neus/utils/__init__.py b/svrm/ldm/modules/rendering_neus/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3910899b3e0bfba650b294c1dd08a559a234933f
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/utils/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
diff --git a/svrm/ldm/modules/rendering_neus/utils/math_utils.py b/svrm/ldm/modules/rendering_neus/utils/math_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ca717cfc9d6fb0926d7d8d165661e60ffc27731
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/utils/math_utils.py
@@ -0,0 +1,118 @@
+# MIT License
+
+# Copyright (c) 2022 Petr Kellnhofer
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+
+def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
+ """
+ Left-multiplies MxM @ NxM. Returns NxM.
+ """
+ res = torch.matmul(vectors4, matrix.T)
+ return res
+
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize vector lengths.
+ """
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+ """
+ Dot product of two tensors.
+ """
+ return (x * y).sum(-1)
+
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
+ """
+ Author: Petr Kellnhofer
+ Intersects rays with the [-1, 1] NDC volume.
+ Returns min and max distance of entry.
+ Returns -1 for no intersection.
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
+ """
+ o_shape = rays_o.shape
+ rays_o = rays_o.detach().reshape(-1, 3)
+ rays_d = rays_d.detach().reshape(-1, 3)
+
+
+ bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
+ bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+ # Precompute inverse for stability.
+ invdir = 1 / rays_d
+ sign = (invdir < 0).long()
+
+ # Intersect with YZ plane.
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+
+ # Intersect with XZ plane.
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tymin)
+ tmax = torch.min(tmax, tymax)
+
+ # Intersect with XY plane.
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tzmin)
+ tmax = torch.min(tmax, tzmax)
+
+ # Mark invalid.
+ tmin[torch.logical_not(is_valid)] = -1
+ tmax[torch.logical_not(is_valid)] = -2
+
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+ """
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
+ Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
+ """
+ # create a tensor of 'num' steps from 0 to 1
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
+
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
+ for i in range(start.ndim):
+ steps = steps.unsqueeze(-1)
+
+ # the output starts at 'start' and increments until 'stop' in each dimension
+ out = start[None] + steps * (stop - start)[None]
+
+ return out
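A small sketch of how the two helpers above combine for ray/box sampling (the import path mirrors where this file is added; adjust it to your setup):

    import torch
    from svrm.ldm.modules.rendering_neus.utils import math_utils

    rays_o = torch.zeros(1, 4, 3)                                  # four rays starting at the origin
    rays_d = math_utils.normalize_vecs(torch.randn(1, 4, 3))

    # entry/exit distances against the [-1, 1]^3 box (box_side_length = 2)
    t_near, t_far = math_utils.get_ray_limits_box(rays_o, rays_d, box_side_length=2.0)
    valid = t_far > t_near                                         # misses come back as (-1, -2)

    # batched linspace: 64 depth samples per ray between the per-ray limits
    depths = math_utils.linspace(t_near, t_far, 64)                # shape [64, 1, 4, 1]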
diff --git a/svrm/ldm/modules/rendering_neus/utils/ray_marcher.py b/svrm/ldm/modules/rendering_neus/utils/ray_marcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..753300ff6ce30fe3df2fd198f8e6ba4394d7b309
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/utils/ray_marcher.py
@@ -0,0 +1,140 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Zexin He
+# The modifications are subject to the same license as the original.
+
+
+"""
+The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
+Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class LearnedVariance(nn.Module):
+ def __init__(self, init_val):
+ super(LearnedVariance, self).__init__()
+ self.register_parameter("_inv_std", nn.Parameter(torch.tensor(init_val)))
+
+ @property
+ def inv_std(self):
+ val = torch.exp(self._inv_std * 10.0)
+ return val
+
+ def forward(self, x):
+ return torch.ones_like(x) * self.inv_std.clamp(1.0e-6, 1.0e6)
+
+
+class MipRayMarcher2(nn.Module):
+ def __init__(self, activation_factory):
+ super().__init__()
+ self.activation_factory = activation_factory
+ self.variance = LearnedVariance(0.3)
+ self.cos_anneal_ratio = 1.0
+ def get_alpha(self, sdf, normal, dirs, dists):
+ # sdf: [N 1] normal: [N 3] dirs: [N 3] dists: [N 1]
+ # import ipdb; ipdb.set_trace()
+
+ inv_std = self.variance(sdf)
+
+ true_cos = (dirs * normal).sum(-1, keepdim=True)
+ # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
+ # the cos value "not dead" at the beginning training iterations, for better convergence.
+ iter_cos = -(
+ F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
+ + F.relu(-true_cos) * self.cos_anneal_ratio
+ ) # always non-positive
+
+ # Estimate signed distances at section points
+ estimated_next_sdf = sdf + iter_cos * dists * 0.5
+ estimated_prev_sdf = sdf - iter_cos * dists * 0.5
+
+ prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std)
+ next_cdf = torch.sigmoid(estimated_next_sdf * inv_std)
+
+ p = prev_cdf - next_cdf
+ c = prev_cdf
+
+ alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0)
+ return alpha
+
+ def run_forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
+ # depths: [B N_ray*N_sample 1]
+ # sdfs: [B, N_ray, N_sample 1]
+ # import ipdb; ipdb.set_trace()
+
+ deltas = depths[:, :, 1:] - depths[:, :, :-1]
+ colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
+ sdfs_mid = (sdfs[:, :, :-1] + sdfs[:, :, 1:]) / 2
+ depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
+ normals_mid = (normals[:, :, :-1] + normals[:, :, 1:]) / 2
+
+ # zhaohx add for normal :
+ real_normals_mid = (real_normals[:, :, :-1] + real_normals[:, :, 1:]) / 2
+
+ # # using factory mode for better usability
+ # densities_mid = self.activation_factory(rendering_options)(densities_mid)
+
+ # density_delta = densities_mid * deltas
+
+ # alpha = 1 - torch.exp(-density_delta)
+
+ # import ipdb; ipdb.set_trace()
+ dirs = ray_directions.unsqueeze(2).expand(-1, -1, sdfs_mid.shape[-2], -1)
+ B, N_ray, N_sample, _ = sdfs_mid.shape
+ alpha = self.get_alpha(sdfs_mid.reshape(-1, 1), normals_mid.reshape(-1, 3), dirs.reshape(-1, 3), deltas.reshape(-1, 1))
+ alpha = alpha.reshape(B, N_ray, N_sample, -1)
+
+ alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
+ weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
+
+ composite_rgb = torch.sum(weights * colors_mid, -2)
+ weight_total = weights.sum(2)
+ composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
+
+ # clip the composite to min/max range of depths
+ composite_depth = torch.nan_to_num(composite_depth, float('inf'))
+ composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
+ # import pdb; pdb.set_trace()
+
+ # zhaohx add for normal :
+ composite_normal = torch.sum(weights * real_normals_mid, -2) / weight_total
+ composite_normal = torch.nan_to_num(composite_normal, float('inf'))
+ composite_normal = torch.clamp(composite_normal, torch.min(real_normals), torch.max(real_normals))
+
+ if rendering_options.get('white_back', False):
+ # composite_rgb = composite_rgb + 1 - weight_total
+ # weight_total[weight_total < 0.5] = 0
+ # composite_rgb = composite_rgb * weight_total + 1 - weight_total
+ # now is this
+ if bgcolor is None:
+ composite_rgb = composite_rgb + 1 - weight_total
+ # composite_rgb = composite_rgb * weight_total + 1 - weight_total
+ else:
+ # import pdb; pdb.set_trace()
+ bgcolor = bgcolor.permute(0, 2, 3, 1).contiguous().view(composite_rgb.shape[0], -1, composite_rgb.shape[-1])
+ composite_rgb = composite_rgb + (1 - weight_total) * bgcolor
+ # composite_rgb = composite_rgb * weight_total + (1 - weight_total) * bgcolor
+ # composite_rgb = composite_rgb
+ # print('new white_back')
+
+ # rendered value scale is 0-1, comment out original mipnerf scaling
+ # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
+
+ return composite_rgb, composite_depth, weights, composite_normal
+
+
+ def forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
+ composite_rgb, composite_depth, weights, composite_normal = self.run_forward(colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor, real_normals)
+
+ return composite_rgb, composite_depth, weights, composite_normal
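For reference, the NeuS-style opacity computed in MipRayMarcher2.get_alpha can be restated as a standalone function; the numbers below are only a toy sanity check (inv_std = exp(0.3 * 10) ≈ 20 matches the LearnedVariance initialisation above):

    import torch

    def neus_alpha(sdf, true_cos, dist, inv_std, cos_anneal_ratio=1.0):
        # same algebra as MipRayMarcher2.get_alpha, one sample per entry
        iter_cos = -(torch.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio)
                     + torch.relu(-true_cos) * cos_anneal_ratio)       # always <= 0
        prev_cdf = torch.sigmoid((sdf - iter_cos * dist * 0.5) * inv_std)
        next_cdf = torch.sigmoid((sdf + iter_cos * dist * 0.5) * inv_std)
        return ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clip(0.0, 1.0)

    # a ray hitting the surface head-on (cos = -1): opacity rises as the SDF
    # changes sign from outside (+) to inside (-)
    print(neus_alpha(torch.tensor([0.2, 0.0, -0.2]),
                     torch.tensor([-1.0, -1.0, -1.0]),
                     torch.tensor(0.05), inv_std=torch.tensor(20.0)))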
diff --git a/svrm/ldm/modules/rendering_neus/utils/ray_sampler.py b/svrm/ldm/modules/rendering_neus/utils/ray_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7501478dbac004c3d0a9c0366f989af0b52ef5db
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/utils/ray_sampler.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Zexin He
+# The modifications are subject to the same license as the original.
+
+
+"""
+The ray sampler is a module that takes in camera matrices and a render resolution and returns batches of rays.
+Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
+"""
+
+import torch
+
+class RaySampler(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
+
+
+ def forward(self, cam2world_matrix, intrinsics, render_size):
+ """
+ Create batches of rays and return origins and directions.
+
+ cam2world_matrix: (N, 4, 4)
+ intrinsics: (N, 3, 3)
+ render_size: int
+
+ ray_origins: (N, M, 3)
+ ray_dirs: (N, M, 3)
+ """
+
+ N, M = cam2world_matrix.shape[0], render_size**2
+ cam_locs_world = cam2world_matrix[:, :3, 3]
+ fx = intrinsics[:, 0, 0]
+ fy = intrinsics[:, 1, 1]
+ cx = intrinsics[:, 0, 2]
+ cy = intrinsics[:, 1, 2]
+ sk = intrinsics[:, 0, 1]
+
+ uv = torch.stack(torch.meshgrid(
+ torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
+ torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
+ indexing='ij',
+ ))
+ uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
+ uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
+
+ x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
+ y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
+ z_cam = torch.ones((N, M), device=cam2world_matrix.device)
+
+ x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
+ y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
+
+ cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
+
+ _opencv2blender = torch.tensor([
+ [1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, -1, 0],
+ [0, 0, 0, 1],
+ ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
+
+ cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
+
+ world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
+
+ ray_dirs = world_rel_points - cam_locs_world[:, None, :]
+ ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
+
+ ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
+
+ return ray_origins, ray_dirs
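A usage sketch for RaySampler; both the import path and the assumption that intrinsics are given in normalized image units (suggested by the 1/render_size scaling above) are inferred rather than documented:

    import torch
    from svrm.ldm.modules.rendering_neus.utils.ray_sampler import RaySampler

    sampler = RaySampler()

    cam2world = torch.eye(4).unsqueeze(0)              # (1, 4, 4) camera pose
    intrinsics = torch.tensor([[[1.0, 0.0, 0.5],       # fx, skew, cx
                                [0.0, 1.0, 0.5],       # fy, cy in normalized units
                                [0.0, 0.0, 1.0]]])     # (1, 3, 3)

    ray_origins, ray_dirs = sampler(cam2world, intrinsics, render_size=64)
    print(ray_origins.shape, ray_dirs.shape)           # (1, 4096, 3) and (1, 4096, 3)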
diff --git a/svrm/ldm/modules/rendering_neus/utils/renderer.py b/svrm/ldm/modules/rendering_neus/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e315c72e47a6a4437bb76e40bcc29d38cd7e2c2c
--- /dev/null
+++ b/svrm/ldm/modules/rendering_neus/utils/renderer.py
@@ -0,0 +1,331 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Zexin He
+# The modifications are subject to the same license as the original.
+
+
+"""
+The renderer is a module that takes in rays, decides where to sample along each
+ray, and computes pixel colors using the volume rendering equation.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ray_marcher import MipRayMarcher2
+from . import math_utils
+# from ldm.modules.rendering_neus.third_party.ops import grid_sample
+
+def generate_planes():
+ """
+ Defines planes by the three vectors that form the "axes" of the
+ plane. Should work with arbitrary number of planes and planes of
+ arbitrary orientation.
+
+ Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
+ """
+ return torch.tensor([[[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]],
+ [[1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0]],
+ [[0, 0, 1],
+ [0, 1, 0],
+ [1, 0, 0]]], dtype=torch.float32)
+
+def project_onto_planes(planes, coordinates):
+ """
+ Does a projection of a 3D point onto a batch of 2D planes,
+ returning 2D plane coordinates.
+
+ Takes plane axes of shape n_planes, 3, 3
+ # Takes coordinates of shape N, M, 3
+ # returns projections of shape N*n_planes, M, 2
+ """
+ N, M, C = coordinates.shape
+ n_planes, _, _ = planes.shape
+ coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+ inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+ projections = torch.bmm(coordinates, inv_planes)
+ return projections[..., :2]
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+ assert padding_mode == 'zeros'
+ N, n_planes, C, H, W = plane_features.shape
+ _, M, _ = coordinates.shape
+ plane_features = plane_features.view(N*n_planes, C, H, W)
+
+ coordinates = (2/box_warp) * coordinates # add specific box bounds
+ # print(coordinates.max(), coordinates.min())
+ # import pdb; pdb.set_trace()
+
+ projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+ # output_features = grid_sample.grid_sample_2d(plane_features, projected_coordinates.float().to(plane_features.device)).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+
+ output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+ return output_features
+
+def sample_from_3dgrid(grid, coordinates):
+ """
+ Expects coordinates in shape (batch_size, num_points_per_batch, 3)
+ Expects grid in shape (1, channels, H, W, D)
+ (Also works if grid has batch size)
+ Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
+ """
+ batch_size, n_coords, n_dims = coordinates.shape
+ sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
+ coordinates.reshape(batch_size, 1, 1, -1, n_dims),
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ N, C, H, W, D = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
+ return sampled_features
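A quick check of the plane projection defined above: with the (bug-fixed) plane axes from generate_planes, each plane keeps a different coordinate pair of a 3D point. The import path is an assumption; the snippet works anywhere these two functions are importable.

    import torch
    from svrm.ldm.modules.rendering_neus.utils.renderer import generate_planes, project_onto_planes

    planes = generate_planes()                       # (3, 3, 3)
    point = torch.tensor([[[0.1, 0.2, 0.3]]])        # (N=1, M=1, 3)

    proj = project_onto_planes(planes, point)        # (N * 3, M, 2)
    print(proj.reshape(3, 2))
    # plane 0 -> (x, y) = (0.1, 0.2)
    # plane 1 -> (x, z) = (0.1, 0.3)
    # plane 2 -> (z, y) = (0.3, 0.2)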
+
+class ImportanceRenderer(torch.nn.Module):
+ """
+ Modified original version to filter out-of-box samples as TensoRF does.
+
+ Reference:
+ TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
+ """
+ def __init__(self):
+ super().__init__()
+ self.activation_factory = self._build_activation_factory()
+ self.ray_marcher = MipRayMarcher2(self.activation_factory)
+ self.plane_axes = generate_planes()
+
+ def _build_activation_factory(self):
+ def activation_factory(options: dict):
+ if options['clamp_mode'] == 'softplus':
+ return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
+ else:
+ assert False, "Renderer only supports `clamp_mode`=`softplus`!"
+ return activation_factory
+
+ def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
+ planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
+ """
+ Additional filtering is applied to filter out-of-box samples.
+ Modifications made by Zexin He.
+ """
+
+ # context related variables
+ batch_size, num_rays, samples_per_ray, _ = depths.shape
+ device = depths.device
+
+ # define sample points with depths
+ sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+ # print(f'min bbox: {sample_coordinates.min()}, max bbox: {sample_coordinates.max()}')
+ # import pdb; pdb.set_trace()
+ # filter out-of-box samples
+ mask_inbox = \
+ (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
+ (sample_coordinates <= rendering_options['sampler_bbox_max'])
+ mask_inbox = mask_inbox.all(-1)
+
+ # forward model according to all samples
+ _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
+
+ # set out-of-box samples to zeros(rgb) & -inf(sigma)
+ SAFE_GUARD = 3
+ DATA_TYPE = _out['sdf'].dtype
+ colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
+ normals_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
+ sdfs_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
+
+ # print(DATA_TYPE)
+ # import pdb; pdb.set_trace()
+ # colors_pass[mask_inbox], sdfs_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sdf'][mask_inbox]
+ colors_pass[mask_inbox], sdfs_pass = _out['rgb'][mask_inbox], _out['sdf']
+ normals_pass = _out['normal']
+
+ # reshape back
+ colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
+ sdfs_pass = sdfs_pass.reshape(batch_size, num_rays, samples_per_ray, sdfs_pass.shape[-1])
+ normals_pass = normals_pass.reshape(batch_size, num_rays, samples_per_ray, normals_pass.shape[-1])
+
+ return colors_pass, sdfs_pass, normals_pass, _out['sdf_grad']
+
+ def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, bgcolor=None):
+ # self.plane_axes = self.plane_axes.to(ray_origins.device)
+
+ if rendering_options['ray_start'] == 'auto' == rendering_options['ray_end']:
+ ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) # [1, N_ray, 1]
+ is_ray_valid = ray_end > ray_start
+ if torch.any(is_ray_valid).item():
+ ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
+ ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
+ depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) # [1, N_ray, N_sample, 1]
+ else:
+ # Create stratified depth samples
+ depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+
+
+ # Coarse Pass
+ colors_coarse, sdfs_coarse, normals_coarse, sdf_grad = self._forward_pass(
+ depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
+ planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+
+ # Fine Pass
+ N_importance = rendering_options['depth_resolution_importance']
+ # TODO
+ if N_importance > 0:
+ _, _, weights = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, sdf_grad.reshape(*normals_coarse.shape), ray_directions, rendering_options, bgcolor)
+
+ depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
+
+ colors_fine, densities_fine = self._forward_pass(
+ depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
+ planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+ all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
+ depths_fine, colors_fine, densities_fine)
+ ####
+ # dists = depths_coarse[:, :, 1:, :] - depths_coarse[:, :, :-1, :]
+ # inter = (ray_end - ray_start) / ( rendering_options['depth_resolution'] + rendering_options['depth_resolution_importance'] - 1) # [1, N_ray, 1]
+ # dists = torch.cat([dists, inter.unsqueeze(2), 2])
+ ####
+
+ # Aggregate
+ rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, bgcolor)
+ else:
+ # # import pdb; pdb.set_trace()
+ # dists = depths_coarse[:, :, 1:, :] - depths_coarse[:, :, :-1, :]
+ # inter = (ray_end - ray_start) / ( rendering_options['depth_resolution'] - 1) # [1, N_ray, 1]
+ # dists = torch.cat([dists, inter.unsqueeze(2)], 2)
+ # # import ipdb; ipdb.set_trace()
+
+ # rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, normals_coarse, dists, ray_directions, rendering_options, bgcolor)
+ rgb_final, depth_final, weights, normal_final = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, sdf_grad.reshape(*normals_coarse.shape), ray_directions, rendering_options, bgcolor, normals_coarse)
+ # import ipdb; ipdb.set_trace()
+
+ return rgb_final, depth_final, weights.sum(2), sdf_grad, normal_final
+
+ def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
+ plane_axes = self.plane_axes.to(planes.device)
+
+ out = decoder(sample_directions, sample_coordinates, plane_axes, planes, options)
+ # if options.get('density_noise', 0) > 0:
+ # out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
+ return out
+
+ def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
+ out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
+ out['sigma'] = self.activation_factory(options)(out['sigma'])
+ return out
+
+ def sort_samples(self, all_depths, all_colors, all_densities):
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+ return all_depths, all_colors, all_densities
+
+ def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
+ all_depths = torch.cat([depths1, depths2], dim = -2)
+ all_colors = torch.cat([colors1, colors2], dim = -2)
+ all_densities = torch.cat([densities1, densities2], dim = -2)
+
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+
+ return all_depths, all_colors, all_densities
+
+ def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
+ """
+ Return depths of approximately uniformly spaced samples along rays.
+ """
+ N, M, _ = ray_origins.shape
+ if disparity_space_sampling:
+ depths_coarse = torch.linspace(0,
+ 1,
+ depth_resolution,
+ device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = 1/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+ depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
+ else:
+ if type(ray_start) == torch.Tensor:
+ depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
+ depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+ else:
+ depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+
+ return depths_coarse
+
+ def sample_importance(self, z_vals, weights, N_importance):
+ """
+ Return depths of importance sampled points along rays. See NeRF importance sampling for more.
+ """
+ with torch.no_grad():
+ batch_size, num_rays, samples_per_ray, _ = z_vals.shape
+
+ z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
+ weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
+
+ # smooth weights
+ weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
+ weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
+ weights = weights + 0.01
+
+ z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
+ importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
+ N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
+ return importance_z_vals
+
+ def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
+ """
+ Sample @N_importance samples from @bins with distribution defined by @weights.
+ Inputs:
+ bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
+ weights: (N_rays, N_samples_)
+ N_importance: the number of samples to draw from the distribution
+ det: deterministic or not
+ eps: a small number to prevent division by zero
+ Outputs:
+ samples: the sampled samples
+ """
+ N_rays, N_samples_ = weights.shape
+ weights = weights + eps # prevent division by zero (don't do inplace op!)
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
+ cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
+ cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
+ # padded to 0~1 inclusive
+
+ if det:
+ u = torch.linspace(0, 1, N_importance, device=bins.device)
+ u = u.expand(N_rays, N_importance)
+ else:
+ u = torch.rand(N_rays, N_importance, device=bins.device)
+ u = u.contiguous()
+
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.clamp_min(inds-1, 0)
+ above = torch.clamp_max(inds, N_samples_)
+
+ inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
+ cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
+ bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
+
+ denom = cdf_g[...,1]-cdf_g[...,0]
+ denom[denom < eps] = 1 # a zero denominator means the bin had zero weight and will never be sampled anyway
+ h = self.upsampler(h)
+ h = rearrange(h, '(b d) c h w-> b d c h w', d=3)
+ h = h.type(x.dtype)
+ return h
+ else:
+ h = self.upsampler(h) #[b, h, w, triplane_dim*4]
+ b, height, width, _ = h.shape
+ h = h.view(b, height, width, self.triplane_dim, self.upsample_ratio, self.upsample_ratio) #[b, h, w, triplane_dim, 2, 2]
+ h = h.permute(0,3,1,4,2,5).contiguous() #[b, triplane_dim, h, 2, w, 2]
+ h = h.view(b, self.triplane_dim, height*self.upsample_ratio, width*self.upsample_ratio)
+ h = rearrange(h, '(b d) c h w-> b d c h w', d=3)
+ h = h.type(x.dtype)
+ return h
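The hunk above is truncated mid-file, but the importance sampling it belongs to (ImportanceRenderer.sample_pdf) is plain inverse-CDF sampling. A self-contained restatement of the idea, not the repo's exact code:

    import torch

    bins = torch.linspace(0.0, 1.0, steps=6).unsqueeze(0)     # (1, 6): 5 depth intervals on one ray
    weights = torch.tensor([[0.05, 0.05, 0.8, 0.05, 0.05]])   # most weight in the third interval

    pdf = weights / weights.sum(-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (1, 6), starts at 0, ends at 1

    u = torch.rand(1, 128)                                    # uniform draws to invert the CDF
    inds = torch.searchsorted(cdf, u, right=True)
    below, above = (inds - 1).clamp(min=0), inds.clamp(max=5)

    cdf_lo, cdf_hi = torch.gather(cdf, 1, below), torch.gather(cdf, 1, above)
    bin_lo, bin_hi = torch.gather(bins, 1, below), torch.gather(bins, 1, above)

    t = (u - cdf_lo) / (cdf_hi - cdf_lo).clamp(min=1e-5)
    samples = bin_lo + t * (bin_hi - bin_lo)                  # ~80% of samples land in [0.4, 0.6]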
diff --git a/svrm/ldm/modules/x_transformer.py b/svrm/ldm/modules/x_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c7a4d4f2bd89ce015bc7896e73b55772e508a66
--- /dev/null
+++ b/svrm/ldm/modules/x_transformer.py
@@ -0,0 +1,641 @@
+"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from functools import partial
+from inspect import isfunction
+from collections import namedtuple
+from einops import rearrange, repeat, reduce
+
+# constants
+
+DEFAULT_DIM_HEAD = 64
+
+Intermediates = namedtuple('Intermediates', [
+ 'pre_softmax_attn',
+ 'post_softmax_attn'
+])
+
+LayerIntermediates = namedtuple('Intermediates', [
+ 'hiddens',
+ 'attn_intermediates'
+])
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return emb[None, :, :]
+
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def always(val):
+ def inner(*args, **kwargs):
+ return val
+ return inner
+
+
+def not_equals(val):
+ def inner(x):
+ return x != val
+ return inner
+
+
+def equals(val):
+ def inner(x):
+ return x == val
+ return inner
+
+
+def max_neg_value(tensor):
+ return -torch.finfo(tensor.dtype).max
+
+
+# keyword argument helpers
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(), dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+
+def string_begins_with(prefix, str):
+ return str.startswith(prefix)
+
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+
+# classes
+class Scale(nn.Module):
+ def __init__(self, value, fn):
+ super().__init__()
+ self.value = value
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.value, *rest)
+
+
+class Rezero(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+ self.g = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-8):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class Residual(nn.Module):
+ def forward(self, x, residual):
+ return x + residual
+
+
+class GRUGating(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gru = nn.GRUCell(dim, dim)
+
+ def forward(self, x, residual):
+ gated_output = self.gru(
+ rearrange(x, 'b n d -> (b n) d'),
+ rearrange(residual, 'b n d -> (b n) d')
+ )
+
+ return gated_output.reshape_as(x)
+
+
+# feedforward
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# attention.
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ mask=None,
+ talking_heads=False,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.,
+ on_attn=False
+ ):
+ super().__init__()
+ if use_entmax15:
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.causal = causal
+ self.mask = mask
+
+ inner_dim = dim_head * heads
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ # talking heads
+ self.talking_heads = talking_heads
+ if talking_heads:
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+
+ # explicit topk sparse attention
+ self.sparse_topk = sparse_topk
+
+ # entmax
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
+ self.attn_fn = F.softmax
+
+ # add memory key / values
+ self.num_mem_kv = num_mem_kv
+ if num_mem_kv > 0:
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
+ # attention on attention
+ self.attn_on_attn = on_attn
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ rel_pos=None,
+ sinusoidal_emb=None,
+ prev_attn=None,
+ mem=None
+ ):
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
+ kv_input = default(context, x)
+
+ q_input = x
+ k_input = kv_input
+ v_input = kv_input
+
+ if exists(mem):
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
+
+ if exists(sinusoidal_emb):
+ # in shortformer, the query would start at a position offset depending on the past cached memory
+ offset = k_input.shape[-2] - q_input.shape[-2]
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
+ k_input = k_input + sinusoidal_emb(k_input)
+
+ q = self.to_q(q_input)
+ k = self.to_k(k_input)
+ v = self.to_v(v_input)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+ input_mask = None
+ if any(map(exists, (mask, context_mask))):
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
+ k_mask = q_mask if not exists(context) else context_mask
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
+ input_mask = q_mask * k_mask
+
+ if self.num_mem_kv > 0:
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
+ if exists(input_mask):
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+ mask_value = max_neg_value(dots)
+
+ if exists(prev_attn):
+ dots = dots + prev_attn
+
+ pre_softmax_attn = dots
+
+ if talking_heads:
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
+
+ if exists(rel_pos):
+ dots = rel_pos(dots)
+
+ if exists(input_mask):
+ dots.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ i, j = dots.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
+ mask = F.pad(mask, (j - i, 0), value=False)
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
+ mask = dots < vk
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ attn = self.attn_fn(dots, dim=-1)
+ post_softmax_attn = attn
+
+ attn = self.dropout(attn)
+
+ if talking_heads:
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ intermediates = Intermediates(
+ pre_softmax_attn=pre_softmax_attn,
+ post_softmax_attn=post_softmax_attn
+ )
+
+ return self.to_out(out), intermediates
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
+ position_infused_attn=False,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ **kwargs
+ ):
+ super().__init__()
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
+
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+
+ self.dim = dim
+ self.depth = depth
+ self.layers = nn.ModuleList([])
+
+ self.has_pos_emb = position_infused_attn
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
+ self.rotary_pos_emb = always(None)
+
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
+ self.rel_pos = None
+
+ self.pre_norm = pre_norm
+
+ self.residual_attn = residual_attn
+ self.cross_residual_attn = cross_residual_attn
+
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+ norm_class = RMSNorm if use_rmsnorm else norm_class
+ norm_fn = partial(norm_class, dim)
+
+ norm_fn = nn.Identity if use_rezero else norm_fn
+ branch_fn = Rezero if use_rezero else None
+
+ if cross_attend and not only_cross:
+ default_block = ('a', 'c', 'f')
+ elif cross_attend and only_cross:
+ default_block = ('c', 'f')
+ else:
+ default_block = ('a', 'f')
+
+ if macaron:
+ default_block = ('f',) + default_block
+
+ if exists(custom_layers):
+ layer_types = custom_layers
+ elif exists(par_ratio):
+ par_depth = depth * len(default_block)
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
+ default_block = tuple(filter(not_equals('f'), default_block))
+ par_attn = par_depth // par_ratio
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
+ par_block = default_block + ('f',) * (par_width - len(default_block))
+ par_head = par_block * par_attn
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
+ elif exists(sandwich_coef):
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be positive and no greater than the depth'
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+ else:
+ layer_types = default_block * depth
+
+ self.layer_types = layer_types
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
+
+ for layer_type in self.layer_types:
+ if layer_type == 'a':
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+ elif layer_type == 'c':
+ layer = Attention(dim, heads=heads, **attn_kwargs)
+ elif layer_type == 'f':
+ layer = FeedForward(dim, **ff_kwargs)
+ layer = layer if not macaron else Scale(0.5, layer)
+ else:
+ raise Exception(f'invalid layer type {layer_type}')
+
+ if isinstance(layer, Attention) and exists(branch_fn):
+ layer = branch_fn(layer)
+
+ if gate_residual:
+ residual_fn = GRUGating(dim)
+ else:
+ residual_fn = Residual()
+
+ self.layers.append(nn.ModuleList([
+ norm_fn(),
+ layer,
+ residual_fn
+ ]))
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ mems=None,
+ return_hiddens=False
+ ):
+ hiddens = []
+ intermediates = []
+ prev_attn = None
+ prev_cross_attn = None
+
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = ind == (len(self.layers) - 1)
+
+ if layer_type == 'a':
+ hiddens.append(x)
+ layer_mem = mems.pop(0)
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == 'a':
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
+ prev_attn=prev_attn, mem=layer_mem)
+ elif layer_type == 'c':
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
+ elif layer_type == 'f':
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if layer_type in ('a', 'c'):
+ intermediates.append(inter)
+
+ if layer_type == 'a' and self.residual_attn:
+ prev_attn = inter.pre_softmax_attn
+ elif layer_type == 'c' and self.cross_residual_attn:
+ prev_cross_attn = inter.pre_softmax_attn
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ if return_hiddens:
+ intermediates = LayerIntermediates(
+ hiddens=hiddens,
+ attn_intermediates=intermediates
+ )
+
+ return x, intermediates
+
+ return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs):
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
+ super().__init__(causal=False, **kwargs)
+
+
+
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.,
+ emb_dropout=0.,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True
+ ):
+ super().__init__()
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
+
+ dim = attn_layers.dim
+ emb_dim = default(emb_dim, dim)
+
+ self.max_seq_len = max_seq_len
+ self.max_mem_len = max_mem_len
+ self.num_tokens = num_tokens
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self.init_()
+
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+
+ # memory tokens (like [cls]) from Memory Transformers paper
+ num_memory_tokens = default(num_memory_tokens, 0)
+ self.num_memory_tokens = num_memory_tokens
+ if num_memory_tokens > 0:
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+ # let funnel encoder know number of memory tokens, if specified
+ if hasattr(attn_layers, 'num_memory_tokens'):
+ attn_layers.num_memory_tokens = num_memory_tokens
+
+ def init_(self):
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_mems=False,
+ return_attn=False,
+ mems=None,
+ **kwargs
+ ):
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
+ x = self.token_emb(x)
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+
+ if num_mem > 0:
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
+ x = torch.cat((mem, x), dim=1)
+
+ # auto-handle masking after appending memory tokens
+ if exists(mask):
+ mask = F.pad(mask, (num_mem, 0), value=True)
+
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
+ x = self.norm(x)
+
+ mem, x = x[:, :num_mem], x[:, num_mem:]
+
+ out = self.to_logits(x) if not return_embeddings else x
+
+ if return_mems:
+ hiddens = intermediates.hiddens
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+ return out, new_mems
+
+ if return_attn:
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+ return out, attn_maps
+
+ return out
+
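+# Usage sketch (illustrative only, not part of the original module): it shows how Encoder and
+# TransformerWrapper defined above fit together; the vocabulary size and sequence length are
+# hypothetical.
+#
+#   encoder = Encoder(dim=512, depth=6, heads=8)
+#   model = TransformerWrapper(num_tokens=1024, max_seq_len=256, attn_layers=encoder)
+#   tokens = torch.randint(0, 1024, (1, 256))            # (batch, seq)
+#   logits = model(tokens)                               # (1, 256, 1024)
+#   embeddings = model(tokens, return_embeddings=True)   # (1, 256, 512)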
diff --git a/svrm/ldm/util.py b/svrm/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b32ee92cc886d3d949abbfb9fea9c9633cb4a6c
--- /dev/null
+++ b/svrm/ldm/util.py
@@ -0,0 +1,252 @@
+import os
+import importlib
+from inspect import isfunction
+import cv2
+import time
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+import matplotlib.pyplot as plt
+import torch
+from torch import optim
+import torchvision
+
+
+def pil_rectangle_crop(im):
+ width, height = im.size # Get dimensions
+
+ if width <= height:
+ left = 0
+ right = width
+ top = (height - width)/2
+ bottom = (height + width)/2
+ else:
+ top = 0
+ bottom = height
+ left = (width - height) / 2
+ right = (width + height) / 2
+
+ # Crop the center of the image
+ im = im.crop((left, top, right, bottom))
+ return im
+
+
+def add_margin(pil_img, color, size=256):
+ width, height = pil_img.size
+ result = Image.new(pil_img.mode, (size, size), color)
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
+ return result
+
+
+def load_and_preprocess(interface, input_im):
+ '''
+ :param input_im: input image (PIL Image).
+ :return: (H, W, 3) uint8 array with the background replaced by white.
+ '''
+ # See https://github.com/Ir1d/image-background-remove-tool
+ image = input_im.convert('RGB')
+
+ image_without_background = interface([image])[0]
+ image_without_background = np.array(image_without_background)
+ est_seg = image_without_background > 127
+ image = np.array(image)
+ foreground = est_seg[:, : , -1].astype(np.bool_)
+ image[~foreground] = [255., 255., 255.]
+ x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
+ image = image[y:y+h, x:x+w, :]
+ image = Image.fromarray(np.array(image))
+
+ # resize image so that its long edge is at most 200, then pad to 256 x 256
+ image.thumbnail([200, 200], Image.Resampling.LANCZOS)
+ image = add_margin(image, (255, 255, 255), size=256)
+ image = np.array(image)
+ return image
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
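+# Usage sketch (illustrative only): instantiate_from_config builds an object from a dict with a
+# dotted "target" path and optional "params"; torch.nn.Linear is used here as a stand-in target.
+#
+#   cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
+#   layer = instantiate_from_config(cfg)   # equivalent to torch.nn.Linear(8, 4)
+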
+
+class AdamWwithEMAandWings(optim.Optimizer):
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
+ ema_power=1., param_names=()):
+ """AdamW that saves EMA versions of the parameters."""
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= ema_decay <= 1.0:
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+ ema_power=ema_power, param_names=param_names)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ ema_params_with_grad = []
+ state_sums = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ ema_decay = group['ema_decay']
+ ema_power = group['ema_power']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of parameter values
+ state['param_exp_avg'] = p.detach().float().clone()
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ ema_params_with_grad.append(state['param_exp_avg'])
+
+ if amsgrad:
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+ # update the steps for each param group update
+ state['step'] += 1
+ # record the step after step update
+ state_steps.append(state['step'])
+
+ optim._functional.adamw(params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=False)
+
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+ return loss
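+
+
+# Usage sketch (illustrative only, hyperparameters are hypothetical): besides the AdamW update,
+# the optimizer keeps an EMA copy of every parameter in its state under "param_exp_avg".
+#
+#   model = torch.nn.Linear(8, 4)
+#   opt = AdamWwithEMAandWings(model.parameters(), lr=1e-4, ema_decay=0.9999)
+#   loss = model(torch.randn(2, 8)).sum()
+#   loss.backward()
+#   opt.step()
+#   ema_weights = [opt.state[p]["param_exp_avg"] for p in model.parameters()]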
diff --git a/svrm/ldm/utils/ops.py b/svrm/ldm/utils/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..628f2d6208aa457b40b6292c73dbceda0fc7d4a4
--- /dev/null
+++ b/svrm/ldm/utils/ops.py
@@ -0,0 +1,538 @@
+import os
+import math
+import numpy as np
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.cuda.amp import custom_bwd, custom_fwd
+from igl import fast_winding_number_for_meshes, point_mesh_squared_distance, read_obj
+
+from .typing import *
+
+
+def get_rank():
+ # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
+ # therefore LOCAL_RANK needs to be checked first
+ rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
+ for key in rank_keys:
+ rank = os.environ.get(key)
+ if rank is not None:
+ return int(rank)
+ return 0
+
+def dot(x, y):
+ return torch.sum(x * y, -1, keepdim=True)
+
+
+def reflect(x, n):
+ return 2 * dot(x, n) * n - x
+
+
+ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
+
+
+def scale_tensor(
+ dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
+):
+ if inp_scale is None:
+ inp_scale = (0, 1)
+ if tgt_scale is None:
+ tgt_scale = (0, 1)
+ if isinstance(tgt_scale, Tensor):
+ assert dat.shape[-1] == tgt_scale.shape[-1]
+ dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
+ dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
+ return dat
+
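+# Usage sketch (illustrative only): rescale coordinates from [0, 1] to [-1, 1].
+#
+#   pts01 = torch.rand(8, 3)
+#   pts = scale_tensor(pts01, (0, 1), (-1, 1))
+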
+
+class _TruncExp(Function): # pylint: disable=abstract-method
+ # Implementation from torch-ngp:
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, x): # pylint: disable=arguments-differ
+ ctx.save_for_backward(x)
+ return torch.exp(x)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, g): # pylint: disable=arguments-differ
+ x = ctx.saved_tensors[0]
+ return g * torch.exp(torch.clamp(x, max=15))
+
+
+class SpecifyGradient(Function):
+ # Implementation from stable-dreamfusion
+ # https://github.com/ashawkey/stable-dreamfusion
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input_tensor, gt_grad):
+ ctx.save_for_backward(gt_grad)
+ # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+ return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_scale):
+ (gt_grad,) = ctx.saved_tensors
+ gt_grad = gt_grad * grad_scale
+ return gt_grad, None
+
+
+trunc_exp = _TruncExp.apply
+
+
+def get_activation(name) -> Callable:
+ if name is None:
+ return lambda x: x
+ name = name.lower()
+ if name == "none":
+ return lambda x: x
+ elif name == "lin2srgb":
+ return lambda x: torch.where(
+ x > 0.0031308,
+ torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
+ 12.92 * x,
+ ).clamp(0.0, 1.0)
+ elif name == "exp":
+ return lambda x: torch.exp(x)
+ elif name == "shifted_exp":
+ return lambda x: torch.exp(x - 1.0)
+ elif name == "trunc_exp":
+ return trunc_exp
+ elif name == "shifted_trunc_exp":
+ return lambda x: trunc_exp(x - 1.0)
+ elif name == "sigmoid":
+ return lambda x: torch.sigmoid(x)
+ elif name == "tanh":
+ return lambda x: torch.tanh(x)
+ elif name == "shifted_softplus":
+ return lambda x: F.softplus(x - 1.0)
+ elif name == "scale_-11_01":
+ return lambda x: x * 0.5 + 0.5
+ else:
+ try:
+ return getattr(F, name)
+ except AttributeError:
+ raise ValueError(f"Unknown activation function: {name}")
+
+
+def chunk_batch(func: Callable, chunk_size: int, triplane=None, *args, **kwargs) -> Any:
+ if chunk_size <= 0:
+ return func(*args, **kwargs)
+ B = None
+ for arg in list(args) + list(kwargs.values()):
+ if isinstance(arg, torch.Tensor):
+ B = arg.shape[0]
+ break
+ assert (
+ B is not None
+ ), "No tensor found in args or kwargs, cannot determine batch size."
+ out = defaultdict(list)
+ out_type = None
+ # max(1, B) to support B == 0
+ for i in range(0, max(1, B), chunk_size):
+ if triplane is not None:
+ out_chunk = func(triplane=triplane,
+ *[
+ arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+ for arg in args
+ ],
+ **{
+ k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+ for k, arg in kwargs.items()
+ },
+ )
+ else:
+ out_chunk = func(
+ *[
+ arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+ for arg in args
+ ],
+ **{
+ k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+ for k, arg in kwargs.items()
+ },
+ )
+ if out_chunk is None:
+ continue
+ out_type = type(out_chunk)
+ if isinstance(out_chunk, torch.Tensor):
+ out_chunk = {0: out_chunk}
+ elif isinstance(out_chunk, (tuple, list)):
+ chunk_length = len(out_chunk)
+ out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
+ elif isinstance(out_chunk, dict):
+ pass
+ else:
+ print(
+ f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
+ )
+ exit(1)
+ for k, v in out_chunk.items():
+ v = v if torch.is_grad_enabled() else v.detach()
+ out[k].append(v)
+
+ if out_type is None:
+ return None
+
+ out_merged: Dict[Any, Optional[torch.Tensor]] = {}
+ for k, v in out.items():
+ if all([vv is None for vv in v]):
+ # allow None in return value
+ out_merged[k] = None
+ elif all([isinstance(vv, torch.Tensor) for vv in v]):
+ out_merged[k] = torch.cat(v, dim=0)
+ else:
+ raise TypeError(
+ f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
+ )
+
+ if out_type is torch.Tensor:
+ return out_merged[0]
+ elif out_type in [tuple, list]:
+ return out_type([out_merged[i] for i in range(chunk_length)])
+ elif out_type is dict:
+ return out_merged
+
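+# Usage sketch (illustrative only, `square` is a hypothetical stand-in): chunk_batch splits the
+# batch dimension of tensor arguments into chunks of `chunk_size`, applies `func` to each chunk
+# and concatenates the results, which keeps peak memory low.
+#
+#   def square(x):
+#       return x ** 2
+#
+#   x = torch.randn(100000, 3)
+#   y = chunk_batch(square, 4096, None, x)   # same values as square(x)
+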
+
+def get_ray_directions(
+ H: int,
+ W: int,
+ focal: Union[float, Tuple[float, float]],
+ principal: Optional[Tuple[float, float]] = None,
+ use_pixel_centers: bool = True,
+) -> Float[Tensor, "H W 3"]:
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+
+ Inputs:
+ H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ pixel_center = 0.5 if use_pixel_centers else 0
+
+ if isinstance(focal, float):
+ fx, fy = focal, focal
+ cx, cy = W / 2, H / 2
+ else:
+ fx, fy = focal
+ assert principal is not None
+ cx, cy = principal
+
+ i, j = torch.meshgrid(
+ torch.arange(W, dtype=torch.float32) + pixel_center,
+ torch.arange(H, dtype=torch.float32) + pixel_center,
+ indexing="xy",
+ )
+
+ directions: Float[Tensor, "H W 3"] = torch.stack(
+ [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1
+ )
+
+ return directions
+
+
+def get_rays(
+ directions: Float[Tensor, "... 3"],
+ c2w: Float[Tensor, "... 4 4"],
+ keepdim=False,
+ noise_scale=0.0,
+) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]:
+ # Rotate ray directions from camera coordinate to the world coordinate
+ assert directions.shape[-1] == 3
+
+ if directions.ndim == 2: # (N_rays, 3)
+ if c2w.ndim == 2: # (4, 4)
+ c2w = c2w[None, :, :]
+ assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
+ rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
+ rays_o = c2w[:, :3, 3].expand(rays_d.shape)
+ elif directions.ndim == 3: # (H, W, 3)
+ assert c2w.ndim in [2, 3]
+ if c2w.ndim == 2: # (4, 4)
+ rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
+ -1
+ ) # (H, W, 3)
+ rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
+ elif c2w.ndim == 3: # (B, 4, 4)
+ rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+ -1
+ ) # (B, H, W, 3)
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+ elif directions.ndim == 4: # (B, H, W, 3)
+ assert c2w.ndim == 3 # (B, 4, 4)
+ rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+ -1
+ ) # (B, H, W, 3)
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+
+ # add camera noise to avoid grid-like artifacts
+ # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
+ if noise_scale > 0:
+ rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
+ rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale
+
+ rays_d = F.normalize(rays_d, dim=-1)
+ if not keepdim:
+ rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
+
+ return rays_o, rays_d
+
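+# Usage sketch (illustrative only, focal length and pose are hypothetical): directions from
+# get_ray_directions live in the camera frame; get_rays rotates them into world space.
+#
+#   dirs = get_ray_directions(H=256, W=256, focal=280.0)   # (256, 256, 3), camera frame
+#   c2w = torch.eye(4)                                     # identity camera-to-world pose
+#   rays_o, rays_d = get_rays(dirs, c2w, keepdim=True)     # each (256, 256, 3), world frame
+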
+
+def get_projection_matrix(
+ fovy: Float[Tensor, "B"], aspect_wh: float, near: float, far: float
+) -> Float[Tensor, "B 4 4"]:
+ batch_size = fovy.shape[0]
+ proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
+ proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
+ proj_mtx[:, 1, 1] = -1.0 / torch.tan(
+ fovy / 2.0
+ ) # add a negative sign here as the y axis is flipped in nvdiffrast output
+ proj_mtx[:, 2, 2] = -(far + near) / (far - near)
+ proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
+ proj_mtx[:, 3, 2] = -1.0
+ return proj_mtx
+
+
+def get_mvp_matrix(
+ c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"]
+) -> Float[Tensor, "B 4 4"]:
+ # calculate w2c from c2w: R' = Rt, t' = -Rt * t
+ # mathematically equivalent to (c2w)^-1
+ w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
+ w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
+ w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
+ w2c[:, 3, 3] = 1.0
+ # calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
+ mvp_mtx = proj_mtx @ w2c
+ return mvp_mtx
+
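+# Sanity-check sketch (illustrative only, with a hypothetical pose): the hand-built w2c above is
+# the rigid-body inverse of c2w, so the MVP matrix matches proj_mtx @ torch.inverse(c2w).
+#
+#   c2w = torch.eye(4)
+#   c2w[:3, :3] = torch.tensor([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])   # 90 deg about z
+#   c2w[:3, 3] = torch.tensor([0.5, 0.0, 1.5])
+#   c2w = c2w.unsqueeze(0)                                                    # (1, 4, 4)
+#   proj = get_projection_matrix(torch.tensor([0.8]), aspect_wh=1.0, near=0.1, far=100.0)
+#   assert torch.allclose(get_mvp_matrix(c2w, proj), proj @ torch.inverse(c2w), atol=1e-5)
+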
+
+def get_full_projection_matrix(
+ c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"]
+) -> Float[Tensor, "B 4 4"]:
+ return (c2w.unsqueeze(0).bmm(proj_mtx.unsqueeze(0))).squeeze(0)
+
+
+# gaussian splatting functions
+def convert_pose(C2W):
+ flip_yz = torch.eye(4, device=C2W.device)
+ flip_yz[1, 1] = -1
+ flip_yz[2, 2] = -1
+ C2W = torch.matmul(C2W, flip_yz)
+ return C2W
+
+
+def get_projection_matrix_gaussian(znear, zfar, fovX, fovY, device="cuda"):
+ tanHalfFovY = math.tan((fovY / 2))
+ tanHalfFovX = math.tan((fovX / 2))
+
+ top = tanHalfFovY * znear
+ bottom = -top
+ right = tanHalfFovX * znear
+ left = -right
+
+ P = torch.zeros(4, 4, device=device)
+
+ z_sign = 1.0
+
+ P[0, 0] = 2.0 * znear / (right - left)
+ P[1, 1] = 2.0 * znear / (top - bottom)
+ P[0, 2] = (right + left) / (right - left)
+ P[1, 2] = (top + bottom) / (top - bottom)
+ P[3, 2] = z_sign
+ P[2, 2] = z_sign * zfar / (zfar - znear)
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
+ return P
+
+
+def get_fov_gaussian(P):
+ tanHalfFovX = 1 / P[0, 0]
+ tanHalfFovY = 1 / P[1, 1]
+ fovY = math.atan(tanHalfFovY) * 2
+ fovX = math.atan(tanHalfFovX) * 2
+ return fovX, fovY
+
+
+def get_cam_info_gaussian(c2w, fovx, fovy, znear, zfar):
+ c2w = convert_pose(c2w)
+ world_view_transform = torch.inverse(c2w)
+
+ world_view_transform = world_view_transform.transpose(0, 1).cuda().float()
+ projection_matrix = (
+ get_projection_matrix_gaussian(znear=znear, zfar=zfar, fovX=fovx, fovY=fovy)
+ .transpose(0, 1)
+ .cuda()
+ )
+ full_proj_transform = (
+ world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))
+ ).squeeze(0)
+ camera_center = world_view_transform.inverse()[3, :3]
+
+ return world_view_transform, full_proj_transform, camera_center
+
+
+def binary_cross_entropy(input, target):
+ """
+ F.binary_cross_entropy is not numerically stable in mixed-precision training.
+ """
+ return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean()
+
+
+def tet_sdf_diff(
+ vert_sdf: Float[Tensor, "Nv 1"], tet_edges: Integer[Tensor, "Ne 2"]
+) -> Float[Tensor, ""]:
+ sdf_f1x6x2 = vert_sdf[:, 0][tet_edges.reshape(-1)].reshape(-1, 2)
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
+ sdf_diff = F.binary_cross_entropy_with_logits(
+ sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()
+ ) + F.binary_cross_entropy_with_logits(
+ sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()
+ )
+ return sdf_diff
+
+
+# Implementation from Latent-NeRF
+# https://github.com/eladrich/latent-nerf/blob/f49ecefcd48972e69a28e3116fe95edf0fac4dc8/src/latent_nerf/models/mesh_utils.py
+class MeshOBJ:
+ dx = torch.zeros(3).float()
+ dx[0] = 1
+ dy, dz = dx[[1, 0, 2]], dx[[2, 1, 0]]
+ dx, dy, dz = dx[None, :], dy[None, :], dz[None, :]
+
+ def __init__(self, v: np.ndarray, f: np.ndarray):
+ self.v = v
+ self.f = f
+ self.dx, self.dy, self.dz = MeshOBJ.dx, MeshOBJ.dy, MeshOBJ.dz
+ self.v_tensor = torch.from_numpy(self.v)
+
+ vf = self.v[self.f, :]
+ self.f_center = vf.mean(axis=1)
+ self.f_center_tensor = torch.from_numpy(self.f_center).float()
+
+ e1 = vf[:, 1, :] - vf[:, 0, :]
+ e2 = vf[:, 2, :] - vf[:, 0, :]
+ self.face_normals = np.cross(e1, e2)
+ self.face_normals = (
+ self.face_normals / np.linalg.norm(self.face_normals, axis=-1)[:, None]
+ )
+ self.face_normals_tensor = torch.from_numpy(self.face_normals)
+
+ def normalize_mesh(self, target_scale=0.5):
+ verts = self.v
+
+ # Compute center of bounding box
+ # center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]]))
+ center = verts.mean(axis=0)
+ verts = verts - center
+ scale = np.max(np.linalg.norm(verts, axis=1))
+ verts = (verts / scale) * target_scale
+
+ return MeshOBJ(verts, self.f)
+
+ def winding_number(self, query: torch.Tensor):
+ device = query.device
+ shp = query.shape
+ query_np = query.detach().cpu().reshape(-1, 3).numpy()
+ target_alphas = fast_winding_number_for_meshes(
+ self.v.astype(np.float32), self.f, query_np
+ )
+ return torch.from_numpy(target_alphas).reshape(shp[:-1]).to(device)
+
+ def gaussian_weighted_distance(self, query: torch.Tensor, sigma):
+ device = query.device
+ shp = query.shape
+ query_np = query.detach().cpu().reshape(-1, 3).numpy()
+ distances, _, _ = point_mesh_squared_distance(
+ query_np, self.v.astype(np.float32), self.f
+ )
+ distances = torch.from_numpy(distances).reshape(shp[:-1]).to(device)
+ weight = torch.exp(-(distances / (2 * sigma**2)))
+ return weight
+
+
+def ce_pq_loss(p, q, weight=None):
+ def clamp(v, T=0.0001):
+ return v.clamp(T, 1 - T)
+
+ p = p.view(q.shape)
+ ce = -1 * (p * torch.log(clamp(q)) + (1 - p) * torch.log(clamp(1 - q)))
+ if weight is not None:
+ ce *= weight
+ return ce.sum()
+
+
+class ShapeLoss(nn.Module):
+ def __init__(self, guide_shape):
+ super().__init__()
+ self.mesh_scale = 0.7
+ self.proximal_surface = 0.3
+ self.delta = 0.2
+ self.shape_path = guide_shape
+ v, _, _, f, _, _ = read_obj(self.shape_path, float)
+ mesh = MeshOBJ(v, f)
+ matrix_rot = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ np.array(
+ [[0, 0, 1], [0, 1, 0], [-1, 0, 0]]
+ )
+ self.sketchshape = mesh.normalize_mesh(self.mesh_scale)
+ self.sketchshape = MeshOBJ(
+ np.ascontiguousarray(
+ (matrix_rot @ self.sketchshape.v.transpose(1, 0)).transpose(1, 0)
+ ),
+ f,
+ )
+
+ def forward(self, xyzs, sigmas):
+ mesh_occ = self.sketchshape.winding_number(xyzs)
+ if self.proximal_surface > 0:
+ weight = 1 - self.sketchshape.gaussian_weighted_distance(
+ xyzs, self.proximal_surface
+ )
+ else:
+ weight = None
+ indicator = (mesh_occ > 0.5).float()
+ nerf_occ = 1 - torch.exp(-self.delta * sigmas)
+ nerf_occ = nerf_occ.clamp(min=0, max=1.1)
+ loss = ce_pq_loss(
+ nerf_occ, indicator, weight=weight
+ ) # argument order matters for the CE loss; the second argument is the target and is not optimized
+ return loss
+
+
+def shifted_expotional_decay(a, b, c, r):
+ return a * torch.exp(-b * r) + c
+
+
+def shifted_cosine_decay(a, b, c, r):
+ return a * torch.cos(b * r + c) + a
+
+
+def perpendicular_component(x: Float[Tensor, "B C H W"], y: Float[Tensor, "B C H W"]):
+ # get the component of x that is perpendicular to y
+ eps = torch.ones_like(x[:, 0, 0, 0]) * 1e-6
+ return (
+ x
+ - (
+ torch.mul(x, y).sum(dim=[1, 2, 3])
+ / torch.maximum(torch.mul(y, y).sum(dim=[1, 2, 3]), eps)
+ ).view(-1, 1, 1, 1)
+ * y
+ )
+
+
+def validate_empty_rays(ray_indices, t_start, t_end):
+ if ray_indices.nelement() == 0:
+ print("Warn Empty rays_indices!")
+ ray_indices = torch.LongTensor([0]).to(ray_indices)
+ t_start = torch.Tensor([0]).to(ray_indices)
+ t_end = torch.Tensor([0]).to(ray_indices)
+ return ray_indices, t_start, t_end
+
diff --git a/svrm/ldm/utils/typing.py b/svrm/ldm/utils/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..35320a646e5bbd16ee7bf91b1ad804a7a1e2c6ac
--- /dev/null
+++ b/svrm/ldm/utils/typing.py
@@ -0,0 +1,38 @@
+"""
+This module contains type annotations for the project, using
+1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
+2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
+
+Two types of typing checking can be used:
+1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
+2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
+"""
+
+# Basic types
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ NamedTuple,
+ NewType,
+ Optional,
+ Sized,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+# Tensor dtype
+# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
+from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
+
+# PyTorch Tensor type
+from torch import Tensor
+
+# Runtime type checking decorator
+from typeguard import typechecked as typechecker
+
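+# Annotation sketch (illustrative only): jaxtyping types combine a dtype with a shape string, so
+# tensor shapes are documented directly in the signature; runtime checking can be added with the
+# typechecker decorator imported above.
+#
+#   def normalize(img: Float[Tensor, "B 3 H W"]) -> Float[Tensor, "B 3 H W"]:
+#       return img / 255.0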
diff --git a/svrm/ldm/vis_util.py b/svrm/ldm/vis_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddddfc74266ebdf8d19bd5106ffb17580d057b4
--- /dev/null
+++ b/svrm/ldm/vis_util.py
@@ -0,0 +1,91 @@
+import os
+from typing import List, Optional
+from PIL import Image
+import imageio
+import time
+import torch
+from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
+from pytorch3d.ops import interpolate_face_attributes
+from pytorch3d.common.datatypes import Device
+from pytorch3d.structures import Meshes
+from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
+from pytorch3d.renderer import (
+ look_at_view_transform,
+ FoVPerspectiveCameras,
+ PointLights,
+ DirectionalLights,
+ AmbientLights,
+ Materials,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+ TexturesUV,
+ TexturesVertex,
+ camera_position_from_spherical_angles,
+ BlendParams,
+)
+
+
+def render(
+ obj_filename,
+ elev=0,
+ azim=0,
+ resolution=512,
+ gif_dst_path='',
+ n_views=120,
+ fps=30,
+ device="cuda:0",
+ rgb=False
+):
+ '''
+ obj_filename: path to the obj file
+ gif_dst_path:
+ if set to a path, renders n_views frames and saves them to a gif file
+ if left empty, renders a single frame and returns a PIL.Image instance
+ rgb: if True, drops the alpha channel and returns rgb images/frames
+ '''
+ # load mesh
+ mesh = load_objs_as_meshes([obj_filename], device=device)
+ meshes = mesh.extend(n_views)
+
+ if gif_dst_path != '':
+ elev = torch.linspace(elev, elev, n_views+1)[:-1]
+ azim = torch.linspace(0, 360, n_views+1)[:-1]
+
+ # prepare R,T then compute cameras
+ R, T = look_at_view_transform(dist=1.5, elev=elev, azim=azim)
+ cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=49.1)
+
+ # init pytorch3d renderer instance
+ renderer = MeshRenderer(
+ rasterizer=MeshRasterizer(
+ cameras=cameras,
+ raster_settings=RasterizationSettings(
+ image_size=resolution,
+ blur_radius=0.0,
+ faces_per_pixel=1,
+ ),
+ ),
+ shader=SoftPhongShader(
+ device=device,
+ cameras=cameras,
+ lights=AmbientLights(device=device),
+ blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
+ )
+ )
+ images = renderer(meshes)
+
+ # single frame rendering
+ if gif_dst_path == '':
+ frame = images[0, ..., :3] if rgb else images[0, ...]
+ frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
+ return frame
+
+ # orbit frames rendering
+ with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
+ for i in range(n_views):
+ frame = images[i, ..., :3] if rgb else images[i, ...]
+ frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
+ writer.append_data(frame)
+ return gif_dst_path
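+
+
+# Usage sketch (illustrative only, the paths are hypothetical):
+#
+#   frame = render("outputs/mesh.obj", elev=15, azim=30, resolution=512)         # PIL.Image
+#   gif = render("outputs/mesh.obj", gif_dst_path="outputs/orbit.gif", fps=30)   # orbit gif path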
diff --git a/svrm/predictor.py b/svrm/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fc4ddd12cbc83caa639c41feba4a2d0e64005a
--- /dev/null
+++ b/svrm/predictor.py
@@ -0,0 +1,150 @@
+# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+# The below software and/or models in this distribution may have been
+# modified by THL A29 Limited ("Tencent Modifications").
+# All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the respective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+import os
+import math
+import time
+import torch
+import numpy as np
+from tqdm import tqdm
+from PIL import Image, ImageSequence
+from omegaconf import OmegaConf
+from torchvision import transforms
+from safetensors.torch import save_file, load_file
+from .ldm.util import instantiate_from_config
+from .ldm.vis_util import render
+
+class MV23DPredictor(object):
+ def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60,
+ render_size=256, device="cuda:0") -> None:
+ self.device = device
+ self.elevation = elevation
+ self.number_view = number_view
+ self.render_size = render_size
+
+ self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
+ self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
+
+ st = time.time()
+ self.model = self.init_model(ckpt_path, cfg_path)
+ print(f"=====> mv23d model init time: {time.time() - st}")
+
+ self.input_view_transform = transforms.Compose([
+ transforms.Resize(504, interpolation=Image.BICUBIC),
+ transforms.ToTensor(),
+ ])
+ self.final_input_view_transform = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+
+ def init_model(self, ckpt_path, cfg_path):
+ config = OmegaConf.load(cfg_path)
+ model = instantiate_from_config(config.model)
+
+ weights = load_file("./weights/svrm/svrm.safetensors")
+ model.load_state_dict(weights)
+
+ model.to(self.device)
+ model = model.eval()
+ model.render.half()
+ print('Loaded model successfully')
+ return model
+
+ def create_camera_to_world_matrix(self, elevation, azimuth, cam_dis=1.5):
+ # elevation and azimuth are given in radians
+ # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
+ x = np.cos(elevation) * np.cos(azimuth)
+ y = np.cos(elevation) * np.sin(azimuth)
+ z = np.sin(elevation)
+
+ # Calculate camera position, target, and up vectors
+ camera_pos = np.array([x, y, z]) * cam_dis
+ target = np.array([0, 0, 0])
+ up = np.array([0, 0, 1])
+
+ # Construct view matrix
+ forward = target - camera_pos
+ forward /= np.linalg.norm(forward)
+ right = np.cross(forward, up)
+ right /= np.linalg.norm(right)
+ new_up = np.cross(right, forward)
+ new_up /= np.linalg.norm(new_up)
+ cam2world = np.eye(4)
+ cam2world[:3, :3] = np.array([right, new_up, -forward]).T
+ cam2world[:3, 3] = camera_pos
+ return cam2world
+
+ def refine_mask(self, mask, k=16):
+ mask /= 255.0
+ border_mask = (mask >= -math.pi / 2.0 / k + 0.5) & (mask <= math.pi / 2.0 / k + 0.5)
+ mask[border_mask] = 0.5 * np.sin(k * (mask[border_mask] - 0.5)) + 0.5
+ mask[mask < -math.pi / 2.0 / k + 0.5] = 0.0
+ mask[mask > math.pi / 2.0 / k + 0.5] = 1.0
+ return (mask * 255.0).astype(np.uint8)
+
+ def load_images_and_cameras(self, input_imgs, elevation_list, azimuth_list):
+ input_image_list = []
+ input_cam_list = []
+ for input_view_image, elevation, azimuth in zip(input_imgs, elevation_list, azimuth_list):
+ input_view_image = self.input_view_transform(input_view_image)
+ input_image_list.append(input_view_image)
+
+ input_view_cam_pos = self.create_camera_to_world_matrix(np.radians(elevation), np.radians(azimuth))
+ input_view_cam_intrinsic = np.array([35. / 32, 35. / 32, 0.5, 0.5])
+ input_view_cam = torch.from_numpy(
+ np.concatenate([input_view_cam_pos.reshape(-1), input_view_cam_intrinsic], 0)
+ ).float()
+ input_cam_list.append(input_view_cam)
+
+ pixels_input = torch.stack(input_image_list, dim=0)
+ input_images = self.final_input_view_transform(pixels_input)
+ input_cams = torch.stack(input_cam_list, dim=0)
+ return input_images, input_cams
+
+ def load_data(self, intput_imgs):
+ assert (6+1) == len(intput_imgs)
+
+ input_images, input_cams = self.load_images_and_cameras(intput_imgs, self.elevation_list, self.azimuth_list)
+ input_cams[-1, :] = 0 # for user input view
+
+ data = {}
+ data["input_view"] = input_images.unsqueeze(0).to(self.device) # 1 4 3 512 512
+ data["input_view_cam"] = input_cams.unsqueeze(0).to(self.device) # 1 4 20
+ return data
+
+ @torch.no_grad()
+ def predict(
+ self,
+ intput_imgs,
+ save_dir = "outputs/",
+ image_input = None,
+ target_face_count = 10000,
+ do_texture_mapping = True,
+ ):
+ os.makedirs(save_dir, exist_ok=True)
+ print(save_dir)
+
+ with torch.cuda.amp.autocast():
+ self.model.export_mesh_with_uv(
+ data = self.load_data(intput_imgs),
+ out_dir = save_dir,
+ target_face_count = target_face_count,
+ do_texture_mapping = do_texture_mapping
+ )
diff --git a/svrm/utils/camera_utils.py b/svrm/utils/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fb15e74ab632a30b8c78b21511b3fbfc6ca46d4
--- /dev/null
+++ b/svrm/utils/camera_utils.py
@@ -0,0 +1,90 @@
+import math
+import numpy as np
+from scipy.spatial.transform import Rotation
+
+def compute_extrinsic_matrix(elevation, azimuth, camera_distance):
+ # convert the angles to radians
+ elevation_rad = np.radians(elevation)
+ azimuth_rad = np.radians(azimuth)
+
+ R = np.array([
+ [np.cos(azimuth_rad), 0, -np.sin(azimuth_rad)],
+ [0, 1, 0],
+ [np.sin(azimuth_rad), 0, np.cos(azimuth_rad)],
+ ], dtype=np.float32)
+
+ R = R @ np.array([
+ [1, 0, 0],
+ [0, np.cos(elevation_rad), -np.sin(elevation_rad)],
+ [0, np.sin(elevation_rad), np.cos(elevation_rad)]
+ ], dtype=np.float32)
+
+ # build the 3x1 translation vector T
+ T = np.array([[camera_distance], [0], [0]], dtype=np.float32)
+ T = R @ T
+
+ # assemble the 4x4 transformation matrix
+ extrinsic_matrix = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]], dtype=np.float32)))
+
+ return extrinsic_matrix
+
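+# Usage sketch (illustrative only): angles are in degrees, the distance is in scene units.
+#
+#   E = compute_extrinsic_matrix(elevation=15, azimuth=30, camera_distance=1.5)
+#   R, t = E[:3, :3], E[:3, 3]   # 3x3 rotation and translation of the camera pose
+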
+
+def transform_camera_pose(im_pose, ori_pose, new_pose):
+ T = new_pose @ ori_pose.T
+ transformed_poses = []
+
+ for pose in im_pose:
+ transformed_pose = T @ pose
+ transformed_poses.append(transformed_pose)
+
+ return transformed_poses
+
+def compute_fov(intrinsic_matrix):
+ # read the focal lengths from the intrinsic matrix
+ fx = intrinsic_matrix[0, 0]
+ fy = intrinsic_matrix[1, 1]
+
+ h, w = intrinsic_matrix[0,2]*2, intrinsic_matrix[1,2]*2
+
+ # compute the horizontal and vertical FOV in degrees
+ fov_x = 2 * math.atan(w / (2 * fx)) * 180 / math.pi
+ fov_y = 2 * math.atan(h / (2 * fy)) * 180 / math.pi
+
+ return fov_x, fov_y
+
+
+
+def rotation_matrix_to_quaternion(rotation_matrix):
+ rot = Rotation.from_matrix(rotation_matrix)
+ quaternion = rot.as_quat()
+ return quaternion
+
+def quaternion_to_rotation_matrix(quaternion):
+ rot = Rotation.from_quat(quaternion)
+ rotation_matrix = rot.as_matrix()
+ return rotation_matrix
+
+def remap_points(img_size, match, size=512):
+ H, W, _ = img_size
+
+ S = max(W, H)
+ new_W = int(round(W * size / S))
+ new_H = int(round(H * size / S))
+ cx, cy = new_W // 2, new_H // 2
+
+ # compute the center coordinates of the transformed image
+ halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
+
+ dw, dh = cx - halfw, cy - halfh
+
+ # allocate a new array for the points mapped back to the original image
+ new_match = np.zeros_like(match)
+
+ # map the transformed point coordinates back to the original image
+ new_match[:, 0] = (match[:, 0] + dw) / new_W * W
+ new_match[:, 1] = (match[:, 1] + dh) / new_H * H
+
+ #print(dw,new_W,W,dh,new_H,H)
+
+ return new_match
+
+
\ No newline at end of file
diff --git a/svrm/utils/img_utils.py b/svrm/utils/img_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b6db1c08ed09d5fe1d774f6057e2a7d2162b08d
--- /dev/null
+++ b/svrm/utils/img_utils.py
@@ -0,0 +1,217 @@
+import os
+import cv2
+import numpy as np
+from skimage.metrics import hausdorff_distance
+from matplotlib import pyplot as plt
+
+
+def get_input_imgs_path(input_data_dir):
+ path = {}
+ names = ['000', 'ori_000']
+ for name in names:
+ jpg_path = os.path.join(input_data_dir, f"{name}.jpg")
+ png_path = os.path.join(input_data_dir, f"{name}.png")
+ if os.path.exists(jpg_path):
+ path[name] = jpg_path
+ elif os.path.exists(png_path):
+ path[name] = png_path
+ return path
+
+
+def rgba_to_rgb(image, bg_color=[255, 255, 255]):
+ if image.shape[-1] == 3: return image
+
+ rgba = image.astype(float)
+ rgb = rgba[:, :, :3].copy()
+ alpha = rgba[:, :, 3] / 255.0
+
+ bg = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32)
+ bg = bg * np.array(bg_color, dtype=np.float32)
+
+ rgb = rgb * alpha[:, :, np.newaxis] + bg * (1 - alpha[:, :, np.newaxis])
+ rgb = rgb.astype(np.uint8)
+
+ return rgb
+
+
+def resize_with_aspect_ratio(image1, image2, pad_value=[255, 255, 255]):
+ aspect_ratio1 = float(image1.shape[1]) / float(image1.shape[0])
+ aspect_ratio2 = float(image2.shape[1]) / float(image2.shape[0])
+
+ top_pad, bottom_pad, left_pad, right_pad = 0, 0, 0, 0
+
+ if aspect_ratio1 < aspect_ratio2:
+ new_width = (aspect_ratio2 * image1.shape[0])
+ right_pad = left_pad = int((new_width - image1.shape[1]) / 2)
+ else:
+ new_height = (image1.shape[1] / aspect_ratio2)
+ bottom_pad = top_pad = int((new_height - image1.shape[0]) / 2)
+
+ image1_padded = cv2.copyMakeBorder(
+ image1, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT, value=pad_value
+ )
+ return image1_padded
+
+
+def estimate_img_mask(image):
+ # convert to grayscale
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+ # threshold segmentation with Otsu's method
+ # _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+ # mask_otsu = thresh.astype(bool)
+ # thresh_gray = 240
+
+ # find edges with the Canny edge detector
+ edges = cv2.Canny(gray, 20, 50)
+
+ # dilate the edges with a morphological operation
+ kernel = np.ones((3, 3), np.uint8)
+ edges_dilated = cv2.dilate(edges, kernel, iterations=1)
+
+ contours, _ = cv2.findContours(edges_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ # create an empty mask
+ mask = np.zeros_like(gray, dtype=np.uint8)
+
+ # fill the mask from the contour information (using thickness=cv2.FILLED)
+ cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
+ mask = mask.astype(bool)
+
+ return mask
+
+
+def compute_img_diff(img1, img2, matches1, matches1_from_2, vis=False):
+ scale = 0.125
+ gray_trunc_thres = 25 / 255.0
+
+ # Match
+ if matches1.shape[0] > 0:
+ match_scale = np.max(np.ptp(matches1, axis=-1))
+ match_dists = np.sqrt(np.sum((matches1 - matches1_from_2) ** 2, axis=-1))
+ dist_threshold = match_scale * 0.01
+ match_num = np.sum(match_dists <= dist_threshold)
+ match_rate = np.mean(match_dists <= dist_threshold)
+ else:
+ match_num = 0
+ match_rate = 0
+
+ # IOU
+ img1_mask = estimate_img_mask(img1)
+ img2_mask = estimate_img_mask(img2)
+ img_intersection = (img1_mask == 1) & (img2_mask == 1)
+ img_union = (img1_mask == 1) | (img2_mask == 1)
+ intersection = np.sum(img_intersection == 1)
+ union = np.sum(img_union == 1)
+ mask_iou = intersection / union if union != 0 else 0
+
+ # Gray
+ height, width = img1.shape[:2]
+ img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
+ img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
+ img1_gray = cv2.GaussianBlur(img1_gray, (7, 7), 0)
+ img2_gray = cv2.GaussianBlur(img2_gray, (7, 7), 0)
+
+ # Gray Diff
+ img1_gray_small = cv2.resize(img1_gray, (int(width * scale), int(height * scale)),
+ interpolation=cv2.INTER_LINEAR) / 255.0
+ img2_gray_small = cv2.resize(img2_gray, (int(width * scale), int(height * scale)),
+ interpolation=cv2.INTER_LINEAR) / 255.0
+ img_gray_small_diff = np.abs(img1_gray_small - img2_gray_small)
+ gray_diff = img_gray_small_diff.sum() / (union * scale) if union != 0 else 1
+
+ img_gray_small_diff_trunc = img_gray_small_diff.copy()
+ img_gray_small_diff_trunc[img_gray_small_diff < gray_trunc_thres] = 0
+ gray_diff_trunc = img_gray_small_diff_trunc.sum() / (union * scale) if union != 0 else 1
+
+ # Edge
+ img1_edge = cv2.Canny(img1_gray, 100, 200)
+ img2_edge = cv2.Canny(img2_gray, 100, 200)
+ bw_edges1 = (img1_edge > 0).astype(bool)
+ bw_edges2 = (img2_edge > 0).astype(bool)
+ hausdorff_dist = hausdorff_distance(bw_edges1, bw_edges2)
+ if vis:
+ fig, axs = plt.subplots(1, 4, figsize=(15, 5))
+ axs[0].imshow(img1_gray, cmap='gray')
+ axs[0].set_title('Img1')
+ axs[1].imshow(img2_gray, cmap='gray')
+ axs[1].set_title('Img2')
+ axs[2].imshow(img1_mask)
+ axs[2].set_title('Mask1')
+ axs[3].imshow(img2_mask)
+ axs[3].set_title('Mask2')
+ plt.show()
+ plt.figure()
+ mask_cmp = np.zeros((height, width, 3))
+ mask_cmp[img_intersection, 1] = 1
+ mask_cmp[img_union, 0] = 1
+ plt.imshow(mask_cmp)
+ plt.show()
+ fig, axs = plt.subplots(1, 4, figsize=(15, 5))
+ axs[0].imshow(img1_gray_small, cmap='gray')
+ axs[0].set_title('Img1 Gray')
+ axs[1].imshow(img2_gray_small, cmap='gray')
+ axs[1].set_title('Img2 Gray')
+ axs[2].imshow(img_gray_small_diff, cmap='gray')
+ axs[2].set_title('diff')
+ axs[3].imshow(img_gray_small_diff_trunc, cmap='gray')
+ axs[3].set_title('diff_trunc')
+ plt.show()
+ fig, axs = plt.subplots(1, 2, figsize=(15, 5))
+ axs[0].imshow(img1_edge, cmap='gray')
+ axs[0].set_title('img1_edge')
+ axs[1].imshow(img2_edge, cmap='gray')
+ axs[1].set_title('img2_edge')
+ plt.show()
+
+ info = {}
+ info['match_num'] = match_num
+ info['match_rate'] = match_rate
+ info['mask_iou'] = mask_iou
+ info['gray_diff'] = gray_diff
+ info['gray_diff_trunc'] = gray_diff_trunc
+ info['hausdorff_dist'] = hausdorff_dist
+ return info
+
+
+def predict_match_success_human(info):
+ match_num = info['match_num']
+ match_rate = info['match_rate']
+ mask_iou = info['mask_iou']
+ gray_diff = info['gray_diff']
+ gray_diff_trunc = info['gray_diff_trunc']
+ hausdorff_dist = info['hausdorff_dist']
+
+ if mask_iou > 0.95:
+ return True
+
+ if match_num < 20 or match_rate < 0.7:
+ return False
+
+ if mask_iou > 0.80 and gray_diff < 0.040 and gray_diff_trunc < 0.010:
+ return True
+
+ if mask_iou > 0.70 and gray_diff < 0.050 and gray_diff_trunc < 0.008:
+ return True
+
+ '''
+ if match_rate<0.70 or match_num<3000:
+ return False
+
+ if (mask_iou>0.85 and hausdorff_dist<20)or (gray_diff<0.015 and gray_diff_trunc<0.01) or match_rate>=0.90:
+ return True
+ '''
+
+ return False
+
+
+def predict_match_success(info, model=None):
+ if model is None:
+ return predict_match_success_human(info)
+ else:
+ feat_name = ['match_num', 'match_rate', 'mask_iou', 'gray_diff', 'gray_diff_trunc', 'hausdorff_dist']
+ # extract features
+ features = [info[f] for f in feat_name]
+ # predict
+ pred = model.predict([features])[0]
+ return pred >= 0.5
\ No newline at end of file
diff --git a/svrm/utils/log_utils.py b/svrm/utils/log_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d0a5abcff36904ac7d748f10f8c94e37007868
--- /dev/null
+++ b/svrm/utils/log_utils.py
@@ -0,0 +1,21 @@
+import cv2
+import numpy as np
+
+def txt_to_img(text, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, font_thickness=2, img_width=1000, img_height=100, text_color=(0, 0, 0), bg_color=(255, 255, 255)):
+ lines = text.split('\n')
+ img_lines = []
+ for line in lines:
+ # measure the size of each text line
+ line_size, _ = cv2.getTextSize(line, font, font_scale, font_thickness)
+ line_width, line_height = line_size
+ # create an image canvas holding the current line
+ img_line = np.full((int(line_height*1.5) , img_width, 3), bg_color, dtype=np.uint8)
+ text_x, text_y = 0, line_height
+ cv2.putText(img_line, line, (text_x, text_y), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
+
+ img_lines.append(img_line)
+
+
+ # stack all line images vertically
+ img = np.vstack(img_lines)
+ return img
\ No newline at end of file