diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..27a5d99b6134654d00d9e2b5c4e9a5bd31e118e9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3245013519348162e619fa6eb1cc9f6584522887 --- /dev/null +++ b/README.md @@ -0,0 +1,16 @@ +--- +title: TRELLIS +emoji: 🏢 +colorFrom: indigo +colorTo: blue +sdk: gradio +sdk_version: 4.44.1 +app_file: app.py +pinned: false +license: mit +short_description: Scalable and Versatile 3D Generation from images +--- + +Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference + +Paper: https://huggingface.co./papers/2412.01506 \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d1943aebc7519ee3f2a2c5db170c8d9e9ee60c16 --- /dev/null +++ b/app.py @@ -0,0 +1,414 @@ +import gradio as gr +import spaces +from gradio_litmodel3d import LitModel3D + +import os +import shutil +os.environ['SPCONV_ALGO'] = 'native' +from typing import * +import torch +import numpy as np +import imageio +from easydict import EasyDict as edict +from PIL import Image +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.representations import Gaussian, MeshExtractResult +from trellis.utils import render_utils, postprocessing_utils + + +MAX_SEED = np.iinfo(np.int32).max +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') +os.makedirs(TMP_DIR, exist_ok=True) + + +def start_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shutil.rmtree(user_dir) + + +def preprocess_image(image: Image.Image) -> Image.Image: + """ + Preprocess the input image. + + Args: + image (Image.Image): The input image. + + Returns: + Image.Image: The preprocessed image. + """ + processed_image = pipeline.preprocess_image(image) + return processed_image + + +def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]: + """ + Preprocess a list of input images. + + Args: + images (List[Tuple[Image.Image, str]]): The input images. + + Returns: + List[Image.Image]: The preprocessed images. + """ + images = [image[0] for image in images] + processed_images = [pipeline.preprocess_image(image) for image in images] + return processed_images + + +def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict: + return { + 'gaussian': { + **gs.init_params, + '_xyz': gs._xyz.cpu().numpy(), + '_features_dc': gs._features_dc.cpu().numpy(), + '_scaling': gs._scaling.cpu().numpy(), + '_rotation': gs._rotation.cpu().numpy(), + '_opacity': gs._opacity.cpu().numpy(), + }, + 'mesh': { + 'vertices': mesh.vertices.cpu().numpy(), + 'faces': mesh.faces.cpu().numpy(), + }, + } + + +def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]: + gs = Gaussian( + aabb=state['gaussian']['aabb'], + sh_degree=state['gaussian']['sh_degree'], + mininum_kernel_size=state['gaussian']['mininum_kernel_size'], + scaling_bias=state['gaussian']['scaling_bias'], + opacity_bias=state['gaussian']['opacity_bias'], + scaling_activation=state['gaussian']['scaling_activation'], + ) + gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda') + gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda') + gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda') + gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda') + gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda') + + mesh = edict( + vertices=torch.tensor(state['mesh']['vertices'], device='cuda'), + faces=torch.tensor(state['mesh']['faces'], device='cuda'), + ) + + return gs, mesh + + +def get_seed(randomize_seed: bool, seed: int) -> int: + """ + Get the random seed. + """ + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +@spaces.GPU +def image_to_3d( + image: Image.Image, + multiimages: List[Tuple[Image.Image, str]], + is_multiimage: bool, + seed: int, + ss_guidance_strength: float, + ss_sampling_steps: int, + slat_guidance_strength: float, + slat_sampling_steps: int, + multiimage_algo: Literal["multidiffusion", "stochastic"], + req: gr.Request, +) -> Tuple[dict, str]: + """ + Convert an image to a 3D model. + + Args: + image (Image.Image): The input image. + multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode. + is_multiimage (bool): Whether is in multi-image mode. + seed (int): The random seed. + ss_guidance_strength (float): The guidance strength for sparse structure generation. + ss_sampling_steps (int): The number of sampling steps for sparse structure generation. + slat_guidance_strength (float): The guidance strength for structured latent generation. + slat_sampling_steps (int): The number of sampling steps for structured latent generation. + multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation. + + Returns: + dict: The information of the generated 3D model. + str: The path to the video of the 3D model. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + if not is_multiimage: + outputs = pipeline.run( + image, + seed=seed, + formats=["gaussian", "mesh"], + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "cfg_strength": ss_guidance_strength, + }, + slat_sampler_params={ + "steps": slat_sampling_steps, + "cfg_strength": slat_guidance_strength, + }, + ) + else: + outputs = pipeline.run_multi_image( + [image[0] for image in multiimages], + seed=seed, + formats=["gaussian", "mesh"], + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "cfg_strength": ss_guidance_strength, + }, + slat_sampler_params={ + "steps": slat_sampling_steps, + "cfg_strength": slat_guidance_strength, + }, + mode=multiimage_algo, + ) + video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color'] + video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal'] + video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))] + video_path = os.path.join(user_dir, 'sample.mp4') + imageio.mimsave(video_path, video, fps=15) + state = pack_state(outputs['gaussian'][0], outputs['mesh'][0]) + torch.cuda.empty_cache() + return state, video_path + + +@spaces.GPU(duration=90) +def extract_glb( + state: dict, + mesh_simplify: float, + texture_size: int, + req: gr.Request, +) -> Tuple[str, str]: + """ + Extract a GLB file from the 3D model. + + Args: + state (dict): The state of the generated 3D model. + mesh_simplify (float): The mesh simplification factor. + texture_size (int): The texture resolution. + + Returns: + str: The path to the extracted GLB file. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + gs, mesh = unpack_state(state) + glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False) + glb_path = os.path.join(user_dir, 'sample.glb') + glb.export(glb_path) + torch.cuda.empty_cache() + return glb_path, glb_path + + +@spaces.GPU +def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]: + """ + Extract a Gaussian file from the 3D model. + + Args: + state (dict): The state of the generated 3D model. + + Returns: + str: The path to the extracted Gaussian file. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + gs, _ = unpack_state(state) + gaussian_path = os.path.join(user_dir, 'sample.ply') + gs.save_ply(gaussian_path) + torch.cuda.empty_cache() + return gaussian_path, gaussian_path + + +def prepare_multi_example() -> List[Image.Image]: + multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")])) + images = [] + for case in multi_case: + _images = [] + for i in range(1, 4): + img = Image.open(f'assets/example_multi_image/{case}_{i}.png') + W, H = img.size + img = img.resize((int(W / H * 512), 512)) + _images.append(np.array(img)) + images.append(Image.fromarray(np.concatenate(_images, axis=1))) + return images + + +def split_image(image: Image.Image) -> List[Image.Image]: + """ + Split an image into multiple views. + """ + image = np.array(image) + alpha = image[..., 3] + alpha = np.any(alpha>0, axis=0) + start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist() + end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist() + images = [] + for s, e in zip(start_pos, end_pos): + images.append(Image.fromarray(image[:, s:e+1])) + return [preprocess_image(image) for image in images] + + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/) + * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background. + * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it. + + ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction. + """) + + with gr.Row(): + with gr.Column(): + with gr.Tabs() as input_tabs: + with gr.Tab(label="Single Image", id=0) as single_image_input_tab: + image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300) + with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab: + multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3) + gr.Markdown(""" + Input different views of the object in separate images. + + *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.* + """) + + with gr.Accordion(label="Generation Settings", open=False): + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) + gr.Markdown("Stage 1: Sparse Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + gr.Markdown("Stage 2: Structured Latent Generation") + with gr.Row(): + slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1) + slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic") + + generate_btn = gr.Button("Generate") + + with gr.Accordion(label="GLB Extraction Settings", open=False): + mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01) + texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512) + + with gr.Row(): + extract_glb_btn = gr.Button("Extract GLB", interactive=False) + extract_gs_btn = gr.Button("Extract Gaussian", interactive=False) + gr.Markdown(""" + *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.* + """) + + with gr.Column(): + video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) + model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300) + + with gr.Row(): + download_glb = gr.DownloadButton(label="Download GLB", interactive=False) + download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False) + + is_multiimage = gr.State(False) + output_buf = gr.State() + + # Example images at the bottom of the page + with gr.Row() as single_image_example: + examples = gr.Examples( + examples=[ + f'assets/example_image/{image}' + for image in os.listdir("assets/example_image") + ], + inputs=[image_prompt], + fn=preprocess_image, + outputs=[image_prompt], + run_on_click=True, + examples_per_page=64, + ) + with gr.Row(visible=False) as multiimage_example: + examples_multi = gr.Examples( + examples=prepare_multi_example(), + inputs=[image_prompt], + fn=split_image, + outputs=[multiimage_prompt], + run_on_click=True, + examples_per_page=8, + ) + + # Handlers + demo.load(start_session) + demo.unload(end_session) + + single_image_input_tab.select( + lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]), + outputs=[is_multiimage, single_image_example, multiimage_example] + ) + multiimage_input_tab.select( + lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]), + outputs=[is_multiimage, single_image_example, multiimage_example] + ) + + image_prompt.upload( + preprocess_image, + inputs=[image_prompt], + outputs=[image_prompt], + ) + multiimage_prompt.upload( + preprocess_images, + inputs=[multiimage_prompt], + outputs=[multiimage_prompt], + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + image_to_3d, + inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo], + outputs=[output_buf, video_output], + ).then( + lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]), + outputs=[extract_glb_btn, extract_gs_btn], + ) + + video_output.clear( + lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]), + outputs=[extract_glb_btn, extract_gs_btn], + ) + + extract_glb_btn.click( + extract_glb, + inputs=[output_buf, mesh_simplify, texture_size], + outputs=[model_output, download_glb], + ).then( + lambda: gr.Button(interactive=True), + outputs=[download_glb], + ) + + extract_gs_btn.click( + extract_gaussian, + inputs=[output_buf], + outputs=[model_output, download_gs], + ).then( + lambda: gr.Button(interactive=True), + outputs=[download_gs], + ) + + model_output.clear( + lambda: gr.Button(interactive=False), + outputs=[download_glb], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") + pipeline.cuda() + try: + pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg + except: + pass + demo.launch() diff --git a/assets/example_image/T.png b/assets/example_image/T.png new file mode 100644 index 0000000000000000000000000000000000000000..79af51bc9d711951fbc63be16b7d07c84294355b Binary files /dev/null and b/assets/example_image/T.png differ diff --git a/assets/example_image/typical_building_building.png b/assets/example_image/typical_building_building.png new file mode 100644 index 0000000000000000000000000000000000000000..4f9adcf79d4297bb9b906608c23c311f9b8f23d2 Binary files /dev/null and b/assets/example_image/typical_building_building.png differ diff --git a/assets/example_image/typical_building_castle.png b/assets/example_image/typical_building_castle.png new file mode 100644 index 0000000000000000000000000000000000000000..5f5f50733f3b8ed168026340ed679df357ccb9ec Binary files /dev/null and b/assets/example_image/typical_building_castle.png differ diff --git a/assets/example_image/typical_building_colorful_cottage.png b/assets/example_image/typical_building_colorful_cottage.png new file mode 100644 index 0000000000000000000000000000000000000000..94616b19be6a896413c287b3168b3b7886d64a56 Binary files /dev/null and b/assets/example_image/typical_building_colorful_cottage.png differ diff --git a/assets/example_image/typical_building_maya_pyramid.png b/assets/example_image/typical_building_maya_pyramid.png new file mode 100644 index 0000000000000000000000000000000000000000..1d87f3c3980f80eee878b4f8ab69e32279a5ea50 Binary files /dev/null and b/assets/example_image/typical_building_maya_pyramid.png differ diff --git a/assets/example_image/typical_building_mushroom.png b/assets/example_image/typical_building_mushroom.png new file mode 100644 index 0000000000000000000000000000000000000000..c4db49d4284e6fec83b2ea548e7e85d489711b84 Binary files /dev/null and b/assets/example_image/typical_building_mushroom.png differ diff --git a/assets/example_image/typical_building_space_station.png b/assets/example_image/typical_building_space_station.png new file mode 100644 index 0000000000000000000000000000000000000000..e37a5806d403dccc098d82492d687d71afa36850 Binary files /dev/null and b/assets/example_image/typical_building_space_station.png differ diff --git a/assets/example_image/typical_creature_dragon.png b/assets/example_image/typical_creature_dragon.png new file mode 100644 index 0000000000000000000000000000000000000000..c3fb92ff0400451c69f44fb75156f50db93da6ec Binary files /dev/null and b/assets/example_image/typical_creature_dragon.png differ diff --git a/assets/example_image/typical_creature_elephant.png b/assets/example_image/typical_creature_elephant.png new file mode 100644 index 0000000000000000000000000000000000000000..6fc3cf1776c66b91e739cad8e839b52074a57e4c Binary files /dev/null and b/assets/example_image/typical_creature_elephant.png differ diff --git a/assets/example_image/typical_creature_furry.png b/assets/example_image/typical_creature_furry.png new file mode 100644 index 0000000000000000000000000000000000000000..eb4e8d6c6cac1e03a206429eaf7de261c6f14072 Binary files /dev/null and b/assets/example_image/typical_creature_furry.png differ diff --git a/assets/example_image/typical_creature_quadruped.png b/assets/example_image/typical_creature_quadruped.png new file mode 100644 index 0000000000000000000000000000000000000000..b246e08e05702051fb22cced1366ab765cd6fbb0 Binary files /dev/null and b/assets/example_image/typical_creature_quadruped.png differ diff --git a/assets/example_image/typical_creature_robot_crab.png b/assets/example_image/typical_creature_robot_crab.png new file mode 100644 index 0000000000000000000000000000000000000000..8b4e10b353e0e9b60634ea272ff8fd9135fdd640 Binary files /dev/null and b/assets/example_image/typical_creature_robot_crab.png differ diff --git a/assets/example_image/typical_creature_robot_dinosour.png b/assets/example_image/typical_creature_robot_dinosour.png new file mode 100644 index 0000000000000000000000000000000000000000..7f8f51728fe1fecb0532673756b1601ef46edc2c Binary files /dev/null and b/assets/example_image/typical_creature_robot_dinosour.png differ diff --git a/assets/example_image/typical_creature_rock_monster.png b/assets/example_image/typical_creature_rock_monster.png new file mode 100644 index 0000000000000000000000000000000000000000..29dc243b197d9b3ee4df9355a5f08752ef0b9b9e Binary files /dev/null and b/assets/example_image/typical_creature_rock_monster.png differ diff --git a/assets/example_image/typical_humanoid_block_robot.png b/assets/example_image/typical_humanoid_block_robot.png new file mode 100644 index 0000000000000000000000000000000000000000..195212e38e6a8e331b02c2d58728ba41dba429a1 Binary files /dev/null and b/assets/example_image/typical_humanoid_block_robot.png differ diff --git a/assets/example_image/typical_humanoid_dragonborn.png b/assets/example_image/typical_humanoid_dragonborn.png new file mode 100644 index 0000000000000000000000000000000000000000..61ca2d9e69634c12ee9ae6f7e77f84839df83fdb Binary files /dev/null and b/assets/example_image/typical_humanoid_dragonborn.png differ diff --git a/assets/example_image/typical_humanoid_dwarf.png b/assets/example_image/typical_humanoid_dwarf.png new file mode 100644 index 0000000000000000000000000000000000000000..16de1631fff3cc42a3a5d6a8b0f638da75ad7b2f Binary files /dev/null and b/assets/example_image/typical_humanoid_dwarf.png differ diff --git a/assets/example_image/typical_humanoid_goblin.png b/assets/example_image/typical_humanoid_goblin.png new file mode 100644 index 0000000000000000000000000000000000000000..4e4fe04517801d5722817e8dfaed2af83b31d67e Binary files /dev/null and b/assets/example_image/typical_humanoid_goblin.png differ diff --git a/assets/example_image/typical_humanoid_mech.png b/assets/example_image/typical_humanoid_mech.png new file mode 100644 index 0000000000000000000000000000000000000000..f0fbbdf6cda5636f517b6e2fa3f20e15e56e3777 Binary files /dev/null and b/assets/example_image/typical_humanoid_mech.png differ diff --git a/assets/example_image/typical_misc_crate.png b/assets/example_image/typical_misc_crate.png new file mode 100644 index 0000000000000000000000000000000000000000..c3086f885bf9fc27c398b5bacfb04a65bd7dfbd9 Binary files /dev/null and b/assets/example_image/typical_misc_crate.png differ diff --git a/assets/example_image/typical_misc_fireplace.png b/assets/example_image/typical_misc_fireplace.png new file mode 100644 index 0000000000000000000000000000000000000000..82d79bc10346604a8b8b9cc8e2c317e8dc6d8c47 Binary files /dev/null and b/assets/example_image/typical_misc_fireplace.png differ diff --git a/assets/example_image/typical_misc_gate.png b/assets/example_image/typical_misc_gate.png new file mode 100644 index 0000000000000000000000000000000000000000..fa77919f9d9faabc26b9287b35c4dd3b4006163e Binary files /dev/null and b/assets/example_image/typical_misc_gate.png differ diff --git a/assets/example_image/typical_misc_lantern.png b/assets/example_image/typical_misc_lantern.png new file mode 100644 index 0000000000000000000000000000000000000000..4c93f5dea2638a5a169dd1557a36f6d34b57144d Binary files /dev/null and b/assets/example_image/typical_misc_lantern.png differ diff --git a/assets/example_image/typical_misc_magicbook.png b/assets/example_image/typical_misc_magicbook.png new file mode 100644 index 0000000000000000000000000000000000000000..7dc521a10fda176694c30170811809050f478a66 Binary files /dev/null and b/assets/example_image/typical_misc_magicbook.png differ diff --git a/assets/example_image/typical_misc_mailbox.png b/assets/example_image/typical_misc_mailbox.png new file mode 100644 index 0000000000000000000000000000000000000000..b6e8bc50cd270bb7462eee2af7a6d5649ef54cf2 Binary files /dev/null and b/assets/example_image/typical_misc_mailbox.png differ diff --git a/assets/example_image/typical_misc_monster_chest.png b/assets/example_image/typical_misc_monster_chest.png new file mode 100644 index 0000000000000000000000000000000000000000..6d544370fa306138e7dbab3e548d6e05b8ef2317 Binary files /dev/null and b/assets/example_image/typical_misc_monster_chest.png differ diff --git a/assets/example_image/typical_misc_paper_machine.png b/assets/example_image/typical_misc_paper_machine.png new file mode 100644 index 0000000000000000000000000000000000000000..a630074dbfe32c53f52f2f27e5b6b3eff8469a9e Binary files /dev/null and b/assets/example_image/typical_misc_paper_machine.png differ diff --git a/assets/example_image/typical_misc_phonograph.png b/assets/example_image/typical_misc_phonograph.png new file mode 100644 index 0000000000000000000000000000000000000000..668662d741344ac16427259fc966186ef8ca97a9 Binary files /dev/null and b/assets/example_image/typical_misc_phonograph.png differ diff --git a/assets/example_image/typical_misc_portal2.png b/assets/example_image/typical_misc_portal2.png new file mode 100644 index 0000000000000000000000000000000000000000..666daa75fbaf7df55585f7143906d158175be6be Binary files /dev/null and b/assets/example_image/typical_misc_portal2.png differ diff --git a/assets/example_image/typical_misc_storage_chest.png b/assets/example_image/typical_misc_storage_chest.png new file mode 100644 index 0000000000000000000000000000000000000000..38f4bd31f8eb62badcc5e1a51d4612e528b4069e Binary files /dev/null and b/assets/example_image/typical_misc_storage_chest.png differ diff --git a/assets/example_image/typical_misc_telephone.png b/assets/example_image/typical_misc_telephone.png new file mode 100644 index 0000000000000000000000000000000000000000..a0a7d65a300d9f1adc55b2fc36951731f5abb355 Binary files /dev/null and b/assets/example_image/typical_misc_telephone.png differ diff --git a/assets/example_image/typical_misc_television.png b/assets/example_image/typical_misc_television.png new file mode 100644 index 0000000000000000000000000000000000000000..1d6b5882b42ce532f6a60080ad55bda7053530c0 Binary files /dev/null and b/assets/example_image/typical_misc_television.png differ diff --git a/assets/example_image/typical_misc_workbench.png b/assets/example_image/typical_misc_workbench.png new file mode 100644 index 0000000000000000000000000000000000000000..88024f960ff56aa619b0c496f85de390076bbf5a Binary files /dev/null and b/assets/example_image/typical_misc_workbench.png differ diff --git a/assets/example_image/typical_vehicle_biplane.png b/assets/example_image/typical_vehicle_biplane.png new file mode 100644 index 0000000000000000000000000000000000000000..7427cad3270d8ed33dad05c7a2ae1b0092b4beb2 Binary files /dev/null and b/assets/example_image/typical_vehicle_biplane.png differ diff --git a/assets/example_image/typical_vehicle_bulldozer.png b/assets/example_image/typical_vehicle_bulldozer.png new file mode 100644 index 0000000000000000000000000000000000000000..17ffe389498d9561ef92766de654ef17b5755f60 Binary files /dev/null and b/assets/example_image/typical_vehicle_bulldozer.png differ diff --git a/assets/example_image/typical_vehicle_cart.png b/assets/example_image/typical_vehicle_cart.png new file mode 100644 index 0000000000000000000000000000000000000000..137bb4887f3879691ffff21227951790eb1840b4 Binary files /dev/null and b/assets/example_image/typical_vehicle_cart.png differ diff --git a/assets/example_image/typical_vehicle_excavator.png b/assets/example_image/typical_vehicle_excavator.png new file mode 100644 index 0000000000000000000000000000000000000000..c434e8b0ab142ecc35caf91d42df5f4541825b8f Binary files /dev/null and b/assets/example_image/typical_vehicle_excavator.png differ diff --git a/assets/example_image/typical_vehicle_helicopter.png b/assets/example_image/typical_vehicle_helicopter.png new file mode 100644 index 0000000000000000000000000000000000000000..39c2497d22ea519cddf576c6504954c338f943e4 Binary files /dev/null and b/assets/example_image/typical_vehicle_helicopter.png differ diff --git a/assets/example_image/typical_vehicle_locomotive.png b/assets/example_image/typical_vehicle_locomotive.png new file mode 100644 index 0000000000000000000000000000000000000000..dac6a2a2de9e8830bac53d3893aa1d3741916b1a Binary files /dev/null and b/assets/example_image/typical_vehicle_locomotive.png differ diff --git a/assets/example_image/typical_vehicle_pirate_ship.png b/assets/example_image/typical_vehicle_pirate_ship.png new file mode 100644 index 0000000000000000000000000000000000000000..9eed1529f309c64fc6237caba97631cc1f2bab53 Binary files /dev/null and b/assets/example_image/typical_vehicle_pirate_ship.png differ diff --git a/assets/example_image/weatherworn_misc_paper_machine3.png b/assets/example_image/weatherworn_misc_paper_machine3.png new file mode 100644 index 0000000000000000000000000000000000000000..46e8a9dc123aaf71a41e994a8e50eabb4e53e721 Binary files /dev/null and b/assets/example_image/weatherworn_misc_paper_machine3.png differ diff --git a/assets/example_multi_image/character_1.png b/assets/example_multi_image/character_1.png new file mode 100644 index 0000000000000000000000000000000000000000..743117c4458af4a9db717c7c7c6b05b6b08037dc Binary files /dev/null and b/assets/example_multi_image/character_1.png differ diff --git a/assets/example_multi_image/character_2.png b/assets/example_multi_image/character_2.png new file mode 100644 index 0000000000000000000000000000000000000000..5fc37f61179b2dc4293a40bec2959ac82fd7503c Binary files /dev/null and b/assets/example_multi_image/character_2.png differ diff --git a/assets/example_multi_image/character_3.png b/assets/example_multi_image/character_3.png new file mode 100644 index 0000000000000000000000000000000000000000..c6e8cb9fb2ab3e86e749405f96ee026273e1e99b Binary files /dev/null and b/assets/example_multi_image/character_3.png differ diff --git a/assets/example_multi_image/mushroom_1.png b/assets/example_multi_image/mushroom_1.png new file mode 100644 index 0000000000000000000000000000000000000000..982645f790b3c374ccab05f33371b2979b5a4031 Binary files /dev/null and b/assets/example_multi_image/mushroom_1.png differ diff --git a/assets/example_multi_image/mushroom_2.png b/assets/example_multi_image/mushroom_2.png new file mode 100644 index 0000000000000000000000000000000000000000..48e961e28e6996bce3259b114da8c6aced35e777 Binary files /dev/null and b/assets/example_multi_image/mushroom_2.png differ diff --git a/assets/example_multi_image/mushroom_3.png b/assets/example_multi_image/mushroom_3.png new file mode 100644 index 0000000000000000000000000000000000000000..16f2022458343da9a084013d8dba92faad0c2103 Binary files /dev/null and b/assets/example_multi_image/mushroom_3.png differ diff --git a/assets/example_multi_image/orangeguy_1.png b/assets/example_multi_image/orangeguy_1.png new file mode 100644 index 0000000000000000000000000000000000000000..5221021f6bde7d3fcb25fe0002ca3de528e4443d Binary files /dev/null and b/assets/example_multi_image/orangeguy_1.png differ diff --git a/assets/example_multi_image/orangeguy_2.png b/assets/example_multi_image/orangeguy_2.png new file mode 100644 index 0000000000000000000000000000000000000000..156aea1f404d7748b73694c34037b1b13d87db66 Binary files /dev/null and b/assets/example_multi_image/orangeguy_2.png differ diff --git a/assets/example_multi_image/orangeguy_3.png b/assets/example_multi_image/orangeguy_3.png new file mode 100644 index 0000000000000000000000000000000000000000..0c15598b97acb46301876c4cdde34aacc8580a68 Binary files /dev/null and b/assets/example_multi_image/orangeguy_3.png differ diff --git a/assets/example_multi_image/popmart_1.png b/assets/example_multi_image/popmart_1.png new file mode 100644 index 0000000000000000000000000000000000000000..81f1838f3a3698441d3d33bf04f0fce5dcdf05f3 Binary files /dev/null and b/assets/example_multi_image/popmart_1.png differ diff --git a/assets/example_multi_image/popmart_2.png b/assets/example_multi_image/popmart_2.png new file mode 100644 index 0000000000000000000000000000000000000000..ac6fdf3c53aa95dd0763a4e865ba7583768f2ac8 Binary files /dev/null and b/assets/example_multi_image/popmart_2.png differ diff --git a/assets/example_multi_image/popmart_3.png b/assets/example_multi_image/popmart_3.png new file mode 100644 index 0000000000000000000000000000000000000000..c83ea960e3aa151151260427d10fc5671619cbee Binary files /dev/null and b/assets/example_multi_image/popmart_3.png differ diff --git a/assets/example_multi_image/rabbit_1.png b/assets/example_multi_image/rabbit_1.png new file mode 100644 index 0000000000000000000000000000000000000000..0cd5708a752cb3951d5edb41165d23f1246955e1 Binary files /dev/null and b/assets/example_multi_image/rabbit_1.png differ diff --git a/assets/example_multi_image/rabbit_2.png b/assets/example_multi_image/rabbit_2.png new file mode 100644 index 0000000000000000000000000000000000000000..95492498199904f299a88ba06a80fc0742f874f9 Binary files /dev/null and b/assets/example_multi_image/rabbit_2.png differ diff --git a/assets/example_multi_image/rabbit_3.png b/assets/example_multi_image/rabbit_3.png new file mode 100644 index 0000000000000000000000000000000000000000..a83285e29702f9680cd5530fa2ab9526e7812352 Binary files /dev/null and b/assets/example_multi_image/rabbit_3.png differ diff --git a/assets/example_multi_image/tiger_1.png b/assets/example_multi_image/tiger_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c4f87f93b63873a81c1e2bda18937a165d49a773 Binary files /dev/null and b/assets/example_multi_image/tiger_1.png differ diff --git a/assets/example_multi_image/tiger_2.png b/assets/example_multi_image/tiger_2.png new file mode 100644 index 0000000000000000000000000000000000000000..8fb9818ab2c6920be720811c04babb4024372c40 Binary files /dev/null and b/assets/example_multi_image/tiger_2.png differ diff --git a/assets/example_multi_image/tiger_3.png b/assets/example_multi_image/tiger_3.png new file mode 100644 index 0000000000000000000000000000000000000000..53689b9b2e3deeeb968628f6fcb636cf6d1223a4 Binary files /dev/null and b/assets/example_multi_image/tiger_3.png differ diff --git a/assets/example_multi_image/yoimiya_1.png b/assets/example_multi_image/yoimiya_1.png new file mode 100644 index 0000000000000000000000000000000000000000..da323f970a288542665e316a8447b1cccf54998d Binary files /dev/null and b/assets/example_multi_image/yoimiya_1.png differ diff --git a/assets/example_multi_image/yoimiya_2.png b/assets/example_multi_image/yoimiya_2.png new file mode 100644 index 0000000000000000000000000000000000000000..d38d854fc264025034ded35f3363e95cf509a0b4 Binary files /dev/null and b/assets/example_multi_image/yoimiya_2.png differ diff --git a/assets/example_multi_image/yoimiya_3.png b/assets/example_multi_image/yoimiya_3.png new file mode 100644 index 0000000000000000000000000000000000000000..f2c8a7ca4085badda0fede3e4ca92dac1565ed24 Binary files /dev/null and b/assets/example_multi_image/yoimiya_3.png differ diff --git a/extensions/nvdiffrast/LICENSE.txt b/extensions/nvdiffrast/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ef3a30911b6676b2fefa11e4e6044beb5b0bc8e --- /dev/null +++ b/extensions/nvdiffrast/LICENSE.txt @@ -0,0 +1,97 @@ +Copyright (c) 2020, NVIDIA Corporation. All rights reserved. + + +Nvidia Source Code License (1-Way Commercial) + +======================================================================= + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. The Work or + derivative works thereof may be used or intended for use by Nvidia + or its affiliates commercially or non-commercially. As used herein, + "non-commercially" means for research or evaluation purposes only + and not for any direct or indirect monetary gain. + + 3.4 Patent Claims. If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grant in Section 2.1) will terminate + immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor's or its affiliates' names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grant in Section 2.1) will + terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +======================================================================= diff --git a/extensions/nvdiffrast/README.md b/extensions/nvdiffrast/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c17baacc4209f34d3d74efdbd7cbb080d0451ca9 --- /dev/null +++ b/extensions/nvdiffrast/README.md @@ -0,0 +1,42 @@ +## Nvdiffrast – Modular Primitives for High-Performance Differentiable Rendering + +![Teaser image](./docs/img/teaser.png) + +**Modular Primitives for High-Performance Differentiable Rendering**
+Samuli Laine, Janne Hellsten, Tero Karras, Yeongho Seol, Jaakko Lehtinen, Timo Aila
+[http://arxiv.org/abs/2011.03277](http://arxiv.org/abs/2011.03277) + +Nvdiffrast is a PyTorch/TensorFlow library that provides high-performance primitive operations for rasterization-based differentiable rendering. +Please refer to ☞☞ [nvdiffrast documentation](https://nvlabs.github.io/nvdiffrast) ☜☜ for more information. + +## Licenses + +Copyright © 2020–2024, NVIDIA Corporation. All rights reserved. + +This work is made available under the [Nvidia Source Code License](https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt). + +For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/) + +We do not currently accept outside code contributions in the form of pull requests. + +Environment map stored as part of `samples/data/envphong.npz` is derived from a Wave Engine +[sample material](https://github.com/WaveEngine/Samples-2.5/tree/master/Materials/EnvironmentMap/Content/Assets/CubeMap.cubemap) +originally shared under +[MIT License](https://github.com/WaveEngine/Samples-2.5/blob/master/LICENSE.md). +Mesh and texture stored as part of `samples/data/earth.npz` are derived from +[3D Earth Photorealistic 2K](https://www.turbosquid.com/3d-models/3d-realistic-earth-photorealistic-2k-1279125) +model originally made available under +[TurboSquid 3D Model License](https://blog.turbosquid.com/turbosquid-3d-model-license/#3d-model-license). + +## Citation + +``` +@article{Laine2020diffrast, + title = {Modular Primitives for High-Performance Differentiable Rendering}, + author = {Samuli Laine and Janne Hellsten and Tero Karras and Yeongho Seol and Jaakko Lehtinen and Timo Aila}, + journal = {ACM Transactions on Graphics}, + year = {2020}, + volume = {39}, + number = {6} +} +``` diff --git a/extensions/nvdiffrast/nvdiffrast/__init__.py b/extensions/nvdiffrast/nvdiffrast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d077dd95e58d52169d5699b6d3ede78a309c21 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +__version__ = '0.3.3' diff --git a/extensions/nvdiffrast/nvdiffrast/common/antialias.cu b/extensions/nvdiffrast/nvdiffrast/common/antialias.cu new file mode 100644 index 0000000000000000000000000000000000000000..a62c1b34bc63788bf2e8f960f8bb2d7d2c9d1d45 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/antialias.cu @@ -0,0 +1,558 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "antialias.h" + +//------------------------------------------------------------------------ +// Helpers. + +#define F32_MAX (3.402823466e+38f) +static __forceinline__ __device__ bool same_sign(float a, float b) { return (__float_as_int(a) ^ __float_as_int(b)) >= 0; } +static __forceinline__ __device__ bool rational_gt(float n0, float n1, float d0, float d1) { return (n0*d1 > n1*d0) == same_sign(d0, d1); } +static __forceinline__ __device__ int max_idx3(float n0, float n1, float n2, float d0, float d1, float d2) +{ + bool g10 = rational_gt(n1, n0, d1, d0); + bool g20 = rational_gt(n2, n0, d2, d0); + bool g21 = rational_gt(n2, n1, d2, d1); + if (g20 && g21) return 2; + if (g10) return 1; + return 0; +} + +//------------------------------------------------------------------------ +// Format of antialiasing work items stored in work buffer. Usually accessed directly as int4. + +struct AAWorkItem +{ + enum + { + EDGE_MASK = 3, // Edge index in lowest bits. + FLAG_DOWN_BIT = 2, // Down instead of right. + FLAG_TRI1_BIT = 3, // Edge is from other pixel's triangle. + }; + + int px, py; // Pixel x, y. + unsigned int pz_flags; // High 16 bits = pixel z, low 16 bits = edge index and flags. + float alpha; // Antialiasing alpha value. Zero if no AA. +}; + +//------------------------------------------------------------------------ +// Hash functions. Adapted from public-domain code at http://www.burtleburtle.net/bob/hash/doobs.html + +#define JENKINS_MAGIC (0x9e3779b9u) +static __device__ __forceinline__ void jenkins_mix(unsigned int& a, unsigned int& b, unsigned int& c) +{ + a -= b; a -= c; a ^= (c>>13); + b -= c; b -= a; b ^= (a<<8); + c -= a; c -= b; c ^= (b>>13); + a -= b; a -= c; a ^= (c>>12); + b -= c; b -= a; b ^= (a<<16); + c -= a; c -= b; c ^= (b>>5); + a -= b; a -= c; a ^= (c>>3); + b -= c; b -= a; b ^= (a<<10); + c -= a; c -= b; c ^= (b>>15); +} + +// Helper class for hash index iteration. Implements simple odd-skip linear probing with a key-dependent skip. +class HashIndex +{ +public: + __device__ __forceinline__ HashIndex(const AntialiasKernelParams& p, uint64_t key) + { + m_mask = (p.allocTriangles << AA_LOG_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles)) - 1; // This should work until triangle count exceeds 1073741824. + m_idx = (uint32_t)(key & 0xffffffffu); + m_skip = (uint32_t)(key >> 32); + uint32_t dummy = JENKINS_MAGIC; + jenkins_mix(m_idx, m_skip, dummy); + m_idx &= m_mask; + m_skip &= m_mask; + m_skip |= 1; + } + __device__ __forceinline__ int get(void) const { return m_idx; } + __device__ __forceinline__ void next(void) { m_idx = (m_idx + m_skip) & m_mask; } +private: + uint32_t m_idx, m_skip, m_mask; +}; + +static __device__ __forceinline__ void hash_insert(const AntialiasKernelParams& p, uint64_t key, int v) +{ + HashIndex idx(p, key); + while(1) + { + uint64_t prev = atomicCAS((unsigned long long*)&p.evHash[idx.get()], 0, (unsigned long long)key); + if (prev == 0 || prev == key) + break; + idx.next(); + } + int* q = (int*)&p.evHash[idx.get()]; + int a = atomicCAS(q+2, 0, v); + if (a != 0 && a != v) + atomicCAS(q+3, 0, v); +} + +static __device__ __forceinline__ int2 hash_find(const AntialiasKernelParams& p, uint64_t key) +{ + HashIndex idx(p, key); + while(1) + { + uint4 entry = p.evHash[idx.get()]; + uint64_t k = ((uint64_t)entry.x) | (((uint64_t)entry.y) << 32); + if (k == key || k == 0) + return make_int2((int)entry.z, (int)entry.w); + idx.next(); + } +} + +static __device__ __forceinline__ void evhash_insert_vertex(const AntialiasKernelParams& p, int va, int vb, int vn) +{ + if (va == vb) + return; + + uint64_t v0 = (uint32_t)min(va, vb) + 1; // canonical vertex order + uint64_t v1 = (uint32_t)max(va, vb) + 1; + uint64_t vk = v0 | (v1 << 32); // hash key + hash_insert(p, vk, vn + 1); +} + +static __forceinline__ __device__ int evhash_find_vertex(const AntialiasKernelParams& p, int va, int vb, int vr) +{ + if (va == vb) + return -1; + + uint64_t v0 = (uint32_t)min(va, vb) + 1; // canonical vertex order + uint64_t v1 = (uint32_t)max(va, vb) + 1; + uint64_t vk = v0 | (v1 << 32); // hash key + int2 vn = hash_find(p, vk) - 1; + if (vn.x == vr) return vn.y; + if (vn.y == vr) return vn.x; + return -1; +} + +//------------------------------------------------------------------------ +// Mesh analysis kernel. + +__global__ void AntialiasFwdMeshKernel(const AntialiasKernelParams p) +{ + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= p.numTriangles) + return; + + int v0 = p.tri[idx * 3 + 0]; + int v1 = p.tri[idx * 3 + 1]; + int v2 = p.tri[idx * 3 + 2]; + + if (v0 < 0 || v0 >= p.numVertices || + v1 < 0 || v1 >= p.numVertices || + v2 < 0 || v2 >= p.numVertices) + return; + + if (v0 == v1 || v1 == v2 || v2 == v0) + return; + + evhash_insert_vertex(p, v1, v2, v0); + evhash_insert_vertex(p, v2, v0, v1); + evhash_insert_vertex(p, v0, v1, v2); +} + +//------------------------------------------------------------------------ +// Discontinuity finder kernel. + +__global__ void AntialiasFwdDiscontinuityKernel(const AntialiasKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH + threadIdx.x; + int py = blockIdx.y * AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.n) + return; + + // Pointer to our TriIdx and fetch. + int pidx0 = ((px + p.width * (py + p.height * pz)) << 2) + 3; + float tri0 = p.rasterOut[pidx0]; // These can stay as float, as we only compare them against each other. + + // Look right, clamp at edge. + int pidx1 = pidx0; + if (px < p.width - 1) + pidx1 += 4; + float tri1 = p.rasterOut[pidx1]; + + // Look down, clamp at edge. + int pidx2 = pidx0; + if (py < p.height - 1) + pidx2 += p.width << 2; + float tri2 = p.rasterOut[pidx2]; + + // Determine amount of work. + int count = 0; + if (tri1 != tri0) count = 1; + if (tri2 != tri0) count += 1; + if (!count) + return; // Exit warp. + + // Coalesce work counter update to once per CTA. + __shared__ int s_temp; + s_temp = 0; + __syncthreads(); + int idx = atomicAdd(&s_temp, count); + __syncthreads(); + if (idx == 0) + { + int base = atomicAdd(&p.workBuffer[0].x, s_temp); + s_temp = base + 1; // don't clobber the counters in first slot. + } + __syncthreads(); + idx += s_temp; + + // Write to memory. + if (tri1 != tri0) p.workBuffer[idx++] = make_int4(px, py, (pz << 16), 0); + if (tri2 != tri0) p.workBuffer[idx] = make_int4(px, py, (pz << 16) + (1 << AAWorkItem::FLAG_DOWN_BIT), 0); +} + +//------------------------------------------------------------------------ +// Forward analysis kernel. + +__global__ void AntialiasFwdAnalysisKernel(const AntialiasKernelParams p) +{ + __shared__ int s_base; + int workCount = p.workBuffer[0].x; + for(;;) + { + // Persistent threads work fetcher. + __syncthreads(); + if (threadIdx.x == 0) + s_base = atomicAdd(&p.workBuffer[0].y, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK); + __syncthreads(); + int thread_idx = s_base + threadIdx.x; + if (thread_idx >= workCount) + return; + + int4* pItem = p.workBuffer + thread_idx + 1; + int4 item = *pItem; + int px = item.x; + int py = item.y; + int pz = (int)(((unsigned int)item.z) >> 16); + int d = (item.z >> AAWorkItem::FLAG_DOWN_BIT) & 1; + + int pixel0 = px + p.width * (py + p.height * pz); + int pixel1 = pixel0 + (d ? p.width : 1); + float2 zt0 = ((float2*)p.rasterOut)[(pixel0 << 1) + 1]; + float2 zt1 = ((float2*)p.rasterOut)[(pixel1 << 1) + 1]; + int tri0 = float_to_triidx(zt0.y) - 1; + int tri1 = float_to_triidx(zt1.y) - 1; + + // Select triangle based on background / depth. + int tri = (tri0 >= 0) ? tri0 : tri1; + if (tri0 >= 0 && tri1 >= 0) + tri = (zt0.x < zt1.x) ? tri0 : tri1; + if (tri == tri1) + { + // Calculate with respect to neighbor pixel if chose that triangle. + px += 1 - d; + py += d; + } + + // Bail out if triangle index is corrupt. + if (tri < 0 || tri >= p.numTriangles) + continue; + + // Fetch vertex indices. + int vi0 = p.tri[tri * 3 + 0]; + int vi1 = p.tri[tri * 3 + 1]; + int vi2 = p.tri[tri * 3 + 2]; + + // Bail out if vertex indices are corrupt. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + continue; + + // Fetch opposite vertex indices. Use vertex itself (always silhouette) if no opposite vertex exists. + int op0 = evhash_find_vertex(p, vi2, vi1, vi0); + int op1 = evhash_find_vertex(p, vi0, vi2, vi1); + int op2 = evhash_find_vertex(p, vi1, vi0, vi2); + + // Instance mode: Adjust vertex indices based on minibatch index. + if (p.instance_mode) + { + int vbase = pz * p.numVertices; + vi0 += vbase; + vi1 += vbase; + vi2 += vbase; + if (op0 >= 0) op0 += vbase; + if (op1 >= 0) op1 += vbase; + if (op2 >= 0) op2 += vbase; + } + + // Fetch vertex positions. + float4 p0 = ((float4*)p.pos)[vi0]; + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + float4 o0 = (op0 < 0) ? p0 : ((float4*)p.pos)[op0]; + float4 o1 = (op1 < 0) ? p1 : ((float4*)p.pos)[op1]; + float4 o2 = (op2 < 0) ? p2 : ((float4*)p.pos)[op2]; + + // Project vertices to pixel space. + float w0 = 1.f / p0.w; + float w1 = 1.f / p1.w; + float w2 = 1.f / p2.w; + float ow0 = 1.f / o0.w; + float ow1 = 1.f / o1.w; + float ow2 = 1.f / o2.w; + float fx = (float)px + .5f - p.xh; + float fy = (float)py + .5f - p.yh; + float x0 = p0.x * w0 * p.xh - fx; + float y0 = p0.y * w0 * p.yh - fy; + float x1 = p1.x * w1 * p.xh - fx; + float y1 = p1.y * w1 * p.yh - fy; + float x2 = p2.x * w2 * p.xh - fx; + float y2 = p2.y * w2 * p.yh - fy; + float ox0 = o0.x * ow0 * p.xh - fx; + float oy0 = o0.y * ow0 * p.yh - fy; + float ox1 = o1.x * ow1 * p.xh - fx; + float oy1 = o1.y * ow1 * p.yh - fy; + float ox2 = o2.x * ow2 * p.xh - fx; + float oy2 = o2.y * ow2 * p.yh - fy; + + // Signs to kill non-silhouette edges. + float bb = (x1-x0)*(y2-y0) - (x2-x0)*(y1-y0); // Triangle itself. + float a0 = (x1-ox0)*(y2-oy0) - (x2-ox0)*(y1-oy0); // Wings. + float a1 = (x2-ox1)*(y0-oy1) - (x0-ox1)*(y2-oy1); + float a2 = (x0-ox2)*(y1-oy2) - (x1-ox2)*(y0-oy2); + + // If no matching signs anywhere, skip the rest. + if (same_sign(a0, bb) || same_sign(a1, bb) || same_sign(a2, bb)) + { + // XY flip for horizontal edges. + if (d) + { + swap(x0, y0); + swap(x1, y1); + swap(x2, y2); + } + + float dx0 = x2 - x1; + float dx1 = x0 - x2; + float dx2 = x1 - x0; + float dy0 = y2 - y1; + float dy1 = y0 - y2; + float dy2 = y1 - y0; + + // Check if an edge crosses between us and the neighbor pixel. + float dc = -F32_MAX; + float ds = (tri == tri0) ? 1.f : -1.f; + float d0 = ds * (x1*dy0 - y1*dx0); + float d1 = ds * (x2*dy1 - y2*dx1); + float d2 = ds * (x0*dy2 - y0*dx2); + + if (same_sign(y1, y2)) d0 = -F32_MAX, dy0 = 1.f; + if (same_sign(y2, y0)) d1 = -F32_MAX, dy1 = 1.f; + if (same_sign(y0, y1)) d2 = -F32_MAX, dy2 = 1.f; + + int di = max_idx3(d0, d1, d2, dy0, dy1, dy2); + if (di == 0 && same_sign(a0, bb) && fabsf(dy0) >= fabsf(dx0)) dc = d0 / dy0; + if (di == 1 && same_sign(a1, bb) && fabsf(dy1) >= fabsf(dx1)) dc = d1 / dy1; + if (di == 2 && same_sign(a2, bb) && fabsf(dy2) >= fabsf(dx2)) dc = d2 / dy2; + float eps = .0625f; // Expect no more than 1/16 pixel inaccuracy. + + // Adjust output image if a suitable edge was found. + if (dc > -eps && dc < 1.f + eps) + { + dc = fminf(fmaxf(dc, 0.f), 1.f); + float alpha = ds * (.5f - dc); + const float* pColor0 = p.color + pixel0 * p.channels; + const float* pColor1 = p.color + pixel1 * p.channels; + float* pOutput = p.output + (alpha > 0.f ? pixel0 : pixel1) * p.channels; + for (int i=0; i < p.channels; i++) + atomicAdd(&pOutput[i], alpha * (pColor1[i] - pColor0[i])); + + // Rewrite the work item's flags and alpha. Keep original px, py. + unsigned int flags = pz << 16; + flags |= di; + flags |= d << AAWorkItem::FLAG_DOWN_BIT; + flags |= (__float_as_uint(ds) >> 31) << AAWorkItem::FLAG_TRI1_BIT; + ((int2*)pItem)[1] = make_int2(flags, __float_as_int(alpha)); + } + } + } +} + +//------------------------------------------------------------------------ +// Gradient kernel. + +__global__ void AntialiasGradKernel(const AntialiasKernelParams p) +{ + // Temporary space for coalesced atomics. + CA_DECLARE_TEMP(AA_GRAD_KERNEL_THREADS_PER_BLOCK); + __shared__ int s_base; // Work counter communication across entire CTA. + + int workCount = p.workBuffer[0].x; + + for(;;) + { + // Persistent threads work fetcher. + __syncthreads(); + if (threadIdx.x == 0) + s_base = atomicAdd(&p.workBuffer[0].y, AA_GRAD_KERNEL_THREADS_PER_BLOCK); + __syncthreads(); + int thread_idx = s_base + threadIdx.x; + if (thread_idx >= workCount) + return; + + // Read work item filled out by forward kernel. + int4 item = p.workBuffer[thread_idx + 1]; + unsigned int amask = __ballot_sync(0xffffffffu, item.w); + if (item.w == 0) + continue; // No effect. + + // Unpack work item and replicate setup from forward analysis kernel. + int px = item.x; + int py = item.y; + int pz = (int)(((unsigned int)item.z) >> 16); + int d = (item.z >> AAWorkItem::FLAG_DOWN_BIT) & 1; + float alpha = __int_as_float(item.w); + int tri1 = (item.z >> AAWorkItem::FLAG_TRI1_BIT) & 1; + int di = item.z & AAWorkItem::EDGE_MASK; + float ds = __int_as_float(__float_as_int(1.0) | (tri1 << 31)); + int pixel0 = px + p.width * (py + p.height * pz); + int pixel1 = pixel0 + (d ? p.width : 1); + int tri = float_to_triidx(p.rasterOut[((tri1 ? pixel1 : pixel0) << 2) + 3]) - 1; + if (tri1) + { + px += 1 - d; + py += d; + } + + // Bail out if triangle index is corrupt. + bool triFail = (tri < 0 || tri >= p.numTriangles); + amask = __ballot_sync(amask, !triFail); + if (triFail) + continue; + + // Outgoing color gradients. + float* pGrad0 = p.gradColor + pixel0 * p.channels; + float* pGrad1 = p.gradColor + pixel1 * p.channels; + + // Incoming color gradients. + const float* pDy = p.dy + (alpha > 0.f ? pixel0 : pixel1) * p.channels; + + // Position gradient weight based on colors and incoming gradients. + float dd = 0.f; + const float* pColor0 = p.color + pixel0 * p.channels; + const float* pColor1 = p.color + pixel1 * p.channels; + + // Loop over channels and accumulate. + for (int i=0; i < p.channels; i++) + { + float dy = pDy[i]; + if (dy != 0.f) + { + // Update position gradient weight. + dd += dy * (pColor1[i] - pColor0[i]); + + // Update color gradients. No coalescing because all have different targets. + float v = alpha * dy; + atomicAdd(&pGrad0[i], -v); + atomicAdd(&pGrad1[i], v); + } + } + + // If position weight is zero, skip the rest. + bool noGrad = (dd == 0.f); + amask = __ballot_sync(amask, !noGrad); + if (noGrad) + continue; + + // Fetch vertex indices of the active edge and their positions. + int i1 = (di < 2) ? (di + 1) : 0; + int i2 = (i1 < 2) ? (i1 + 1) : 0; + int vi1 = p.tri[3 * tri + i1]; + int vi2 = p.tri[3 * tri + i2]; + + // Bail out if vertex indices are corrupt. + bool vtxFail = (vi1 < 0 || vi1 >= p.numVertices || vi2 < 0 || vi2 >= p.numVertices); + amask = __ballot_sync(amask, !vtxFail); + if (vtxFail) + continue; + + // Instance mode: Adjust vertex indices based on minibatch index. + if (p.instance_mode) + { + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Fetch vertex positions. + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + + // Project vertices to pixel space. + float pxh = p.xh; + float pyh = p.yh; + float fx = (float)px + .5f - pxh; + float fy = (float)py + .5f - pyh; + + // XY flip for horizontal edges. + if (d) + { + swap(p1.x, p1.y); + swap(p2.x, p2.y); + swap(pxh, pyh); + swap(fx, fy); + } + + // Gradient calculation setup. + float w1 = 1.f / p1.w; + float w2 = 1.f / p2.w; + float x1 = p1.x * w1 * pxh - fx; + float y1 = p1.y * w1 * pyh - fy; + float x2 = p2.x * w2 * pxh - fx; + float y2 = p2.y * w2 * pyh - fy; + float dx = x2 - x1; + float dy = y2 - y1; + float db = x1*dy - y1*dx; + + // Compute inverse delta-y with epsilon. + float ep = copysignf(1e-3f, dy); // ~1/1000 pixel. + float iy = 1.f / (dy + ep); + + // Compute position gradients. + float dby = db * iy; + float iw1 = -w1 * iy * dd; + float iw2 = w2 * iy * dd; + float gp1x = iw1 * pxh * y2; + float gp2x = iw2 * pxh * y1; + float gp1y = iw1 * pyh * (dby - x2); + float gp2y = iw2 * pyh * (dby - x1); + float gp1w = -(p1.x * gp1x + p1.y * gp1y) * w1; + float gp2w = -(p2.x * gp2x + p2.y * gp2y) * w2; + + // XY flip the gradients. + if (d) + { + swap(gp1x, gp1y); + swap(gp2x, gp2y); + } + + // Kill position gradients if alpha was saturated. + if (fabsf(alpha) >= 0.5f) + { + gp1x = gp1y = gp1w = 0.f; + gp2x = gp2y = gp2w = 0.f; + } + + // Initialize coalesced atomics. Match both triangle ID and edge index. + // Also note that some threads may be inactive. + CA_SET_GROUP_MASK(tri ^ (di << 30), amask); + + // Accumulate gradients. + caAtomicAdd3_xyw(p.gradPos + 4 * vi1, gp1x, gp1y, gp1w); + caAtomicAdd3_xyw(p.gradPos + 4 * vi2, gp2x, gp2y, gp2w); + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/antialias.h b/extensions/nvdiffrast/nvdiffrast/common/antialias.h new file mode 100644 index 0000000000000000000000000000000000000000..bc2fd480e63ca6ad12fc5f4e15f1857880d145c3 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/antialias.h @@ -0,0 +1,50 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "common.h" + +//------------------------------------------------------------------------ +// Constants and helpers. + +#define AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH 32 +#define AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT 8 +#define AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK 256 +#define AA_MESH_KERNEL_THREADS_PER_BLOCK 256 +#define AA_HASH_ELEMENTS_PER_TRIANGLE(alloc) ((alloc) >= (2 << 25) ? 4 : 8) // With more than 16777216 triangles (alloc >= 33554432) use smallest possible value of 4 to conserve memory, otherwise use 8 for fewer collisions. +#define AA_LOG_HASH_ELEMENTS_PER_TRIANGLE(alloc) ((alloc) >= (2 << 25) ? 2 : 3) +#define AA_GRAD_KERNEL_THREADS_PER_BLOCK 256 + +//------------------------------------------------------------------------ +// CUDA kernel params. + +struct AntialiasKernelParams +{ + const float* color; // Incoming color buffer. + const float* rasterOut; // Incoming rasterizer output buffer. + const int* tri; // Incoming triangle buffer. + const float* pos; // Incoming position buffer. + float* output; // Output buffer of forward kernel. + const float* dy; // Incoming gradients. + float* gradColor; // Output buffer, color gradient. + float* gradPos; // Output buffer, position gradient. + int4* workBuffer; // Buffer for storing intermediate work items. First item reserved for counters. + uint4* evHash; // Edge-vertex hash. + int allocTriangles; // Number of triangles accommodated by evHash. Always power of two. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int width; // Input width. + int height; // Input height. + int n; // Minibatch size. + int channels; // Channel count in color input. + float xh, yh; // Transfer to pixel space. + int instance_mode; // 0=normal, 1=instance mode. + int tri_const; // 1 if triangle array is known to be constant. +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/common.cpp b/extensions/nvdiffrast/nvdiffrast/common/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24e59ba6242ea362eac349e5127833908a4f251c --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/common.cpp @@ -0,0 +1,60 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// Block and grid size calculators for kernel launches. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, int width, int height) +{ + int maxThreads = maxWidth * maxHeight; + if (maxThreads <= 1 || (width * height) <= 1) + return dim3(1, 1, 1); // Degenerate. + + // Start from max size. + int bw = maxWidth; + int bh = maxHeight; + + // Optimizations for weirdly sized buffers. + if (width < bw) + { + // Decrease block width to smallest power of two that covers the buffer width. + while ((bw >> 1) >= width) + bw >>= 1; + + // Maximize height. + bh = maxThreads / bw; + if (bh > height) + bh = height; + } + else if (height < bh) + { + // Halve height and double width until fits completely inside buffer vertically. + while (bh > height) + { + bh >>= 1; + if (bw < width) + bw <<= 1; + } + } + + // Done. + return dim3(bw, bh, 1); +} + +dim3 getLaunchGridSize(dim3 blockSize, int width, int height, int depth) +{ + dim3 gridSize; + gridSize.x = (width - 1) / blockSize.x + 1; + gridSize.y = (height - 1) / blockSize.y + 1; + gridSize.z = (depth - 1) / blockSize.z + 1; + return gridSize; +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/common.h b/extensions/nvdiffrast/nvdiffrast/common/common.h new file mode 100644 index 0000000000000000000000000000000000000000..df1e30d40ceb32c247516b68882c678b6dba215a --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/common.h @@ -0,0 +1,263 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include +#include + +//------------------------------------------------------------------------ +// C++ helper function prototypes. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, int width, int height); +dim3 getLaunchGridSize(dim3 blockSize, int width, int height, int depth); + +//------------------------------------------------------------------------ +// The rest is CUDA device code specific stuff. + +#ifdef __CUDACC__ + +//------------------------------------------------------------------------ +// Helpers for CUDA vector types. + +static __device__ __forceinline__ float2& operator*= (float2& a, const float2& b) { a.x *= b.x; a.y *= b.y; return a; } +static __device__ __forceinline__ float2& operator+= (float2& a, const float2& b) { a.x += b.x; a.y += b.y; return a; } +static __device__ __forceinline__ float2& operator-= (float2& a, const float2& b) { a.x -= b.x; a.y -= b.y; return a; } +static __device__ __forceinline__ float2& operator*= (float2& a, float b) { a.x *= b; a.y *= b; return a; } +static __device__ __forceinline__ float2& operator+= (float2& a, float b) { a.x += b; a.y += b; return a; } +static __device__ __forceinline__ float2& operator-= (float2& a, float b) { a.x -= b; a.y -= b; return a; } +static __device__ __forceinline__ float2 operator* (const float2& a, const float2& b) { return make_float2(a.x * b.x, a.y * b.y); } +static __device__ __forceinline__ float2 operator+ (const float2& a, const float2& b) { return make_float2(a.x + b.x, a.y + b.y); } +static __device__ __forceinline__ float2 operator- (const float2& a, const float2& b) { return make_float2(a.x - b.x, a.y - b.y); } +static __device__ __forceinline__ float2 operator* (const float2& a, float b) { return make_float2(a.x * b, a.y * b); } +static __device__ __forceinline__ float2 operator+ (const float2& a, float b) { return make_float2(a.x + b, a.y + b); } +static __device__ __forceinline__ float2 operator- (const float2& a, float b) { return make_float2(a.x - b, a.y - b); } +static __device__ __forceinline__ float2 operator* (float a, const float2& b) { return make_float2(a * b.x, a * b.y); } +static __device__ __forceinline__ float2 operator+ (float a, const float2& b) { return make_float2(a + b.x, a + b.y); } +static __device__ __forceinline__ float2 operator- (float a, const float2& b) { return make_float2(a - b.x, a - b.y); } +static __device__ __forceinline__ float2 operator- (const float2& a) { return make_float2(-a.x, -a.y); } +static __device__ __forceinline__ float3& operator*= (float3& a, const float3& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; } +static __device__ __forceinline__ float3& operator+= (float3& a, const float3& b) { a.x += b.x; a.y += b.y; a.z += b.z; return a; } +static __device__ __forceinline__ float3& operator-= (float3& a, const float3& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; } +static __device__ __forceinline__ float3& operator*= (float3& a, float b) { a.x *= b; a.y *= b; a.z *= b; return a; } +static __device__ __forceinline__ float3& operator+= (float3& a, float b) { a.x += b; a.y += b; a.z += b; return a; } +static __device__ __forceinline__ float3& operator-= (float3& a, float b) { a.x -= b; a.y -= b; a.z -= b; return a; } +static __device__ __forceinline__ float3 operator* (const float3& a, const float3& b) { return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); } +static __device__ __forceinline__ float3 operator+ (const float3& a, const float3& b) { return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); } +static __device__ __forceinline__ float3 operator- (const float3& a, const float3& b) { return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); } +static __device__ __forceinline__ float3 operator* (const float3& a, float b) { return make_float3(a.x * b, a.y * b, a.z * b); } +static __device__ __forceinline__ float3 operator+ (const float3& a, float b) { return make_float3(a.x + b, a.y + b, a.z + b); } +static __device__ __forceinline__ float3 operator- (const float3& a, float b) { return make_float3(a.x - b, a.y - b, a.z - b); } +static __device__ __forceinline__ float3 operator* (float a, const float3& b) { return make_float3(a * b.x, a * b.y, a * b.z); } +static __device__ __forceinline__ float3 operator+ (float a, const float3& b) { return make_float3(a + b.x, a + b.y, a + b.z); } +static __device__ __forceinline__ float3 operator- (float a, const float3& b) { return make_float3(a - b.x, a - b.y, a - b.z); } +static __device__ __forceinline__ float3 operator- (const float3& a) { return make_float3(-a.x, -a.y, -a.z); } +static __device__ __forceinline__ float4& operator*= (float4& a, const float4& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; } +static __device__ __forceinline__ float4& operator+= (float4& a, const float4& b) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; } +static __device__ __forceinline__ float4& operator-= (float4& a, const float4& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; } +static __device__ __forceinline__ float4& operator*= (float4& a, float b) { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; } +static __device__ __forceinline__ float4& operator+= (float4& a, float b) { a.x += b; a.y += b; a.z += b; a.w += b; return a; } +static __device__ __forceinline__ float4& operator-= (float4& a, float b) { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; } +static __device__ __forceinline__ float4 operator* (const float4& a, const float4& b) { return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +static __device__ __forceinline__ float4 operator+ (const float4& a, const float4& b) { return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +static __device__ __forceinline__ float4 operator- (const float4& a, const float4& b) { return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +static __device__ __forceinline__ float4 operator* (const float4& a, float b) { return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); } +static __device__ __forceinline__ float4 operator+ (const float4& a, float b) { return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); } +static __device__ __forceinline__ float4 operator- (const float4& a, float b) { return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); } +static __device__ __forceinline__ float4 operator* (float a, const float4& b) { return make_float4(a * b.x, a * b.y, a * b.z, a * b.w); } +static __device__ __forceinline__ float4 operator+ (float a, const float4& b) { return make_float4(a + b.x, a + b.y, a + b.z, a + b.w); } +static __device__ __forceinline__ float4 operator- (float a, const float4& b) { return make_float4(a - b.x, a - b.y, a - b.z, a - b.w); } +static __device__ __forceinline__ float4 operator- (const float4& a) { return make_float4(-a.x, -a.y, -a.z, -a.w); } +static __device__ __forceinline__ int2& operator*= (int2& a, const int2& b) { a.x *= b.x; a.y *= b.y; return a; } +static __device__ __forceinline__ int2& operator+= (int2& a, const int2& b) { a.x += b.x; a.y += b.y; return a; } +static __device__ __forceinline__ int2& operator-= (int2& a, const int2& b) { a.x -= b.x; a.y -= b.y; return a; } +static __device__ __forceinline__ int2& operator*= (int2& a, int b) { a.x *= b; a.y *= b; return a; } +static __device__ __forceinline__ int2& operator+= (int2& a, int b) { a.x += b; a.y += b; return a; } +static __device__ __forceinline__ int2& operator-= (int2& a, int b) { a.x -= b; a.y -= b; return a; } +static __device__ __forceinline__ int2 operator* (const int2& a, const int2& b) { return make_int2(a.x * b.x, a.y * b.y); } +static __device__ __forceinline__ int2 operator+ (const int2& a, const int2& b) { return make_int2(a.x + b.x, a.y + b.y); } +static __device__ __forceinline__ int2 operator- (const int2& a, const int2& b) { return make_int2(a.x - b.x, a.y - b.y); } +static __device__ __forceinline__ int2 operator* (const int2& a, int b) { return make_int2(a.x * b, a.y * b); } +static __device__ __forceinline__ int2 operator+ (const int2& a, int b) { return make_int2(a.x + b, a.y + b); } +static __device__ __forceinline__ int2 operator- (const int2& a, int b) { return make_int2(a.x - b, a.y - b); } +static __device__ __forceinline__ int2 operator* (int a, const int2& b) { return make_int2(a * b.x, a * b.y); } +static __device__ __forceinline__ int2 operator+ (int a, const int2& b) { return make_int2(a + b.x, a + b.y); } +static __device__ __forceinline__ int2 operator- (int a, const int2& b) { return make_int2(a - b.x, a - b.y); } +static __device__ __forceinline__ int2 operator- (const int2& a) { return make_int2(-a.x, -a.y); } +static __device__ __forceinline__ int3& operator*= (int3& a, const int3& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; } +static __device__ __forceinline__ int3& operator+= (int3& a, const int3& b) { a.x += b.x; a.y += b.y; a.z += b.z; return a; } +static __device__ __forceinline__ int3& operator-= (int3& a, const int3& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; } +static __device__ __forceinline__ int3& operator*= (int3& a, int b) { a.x *= b; a.y *= b; a.z *= b; return a; } +static __device__ __forceinline__ int3& operator+= (int3& a, int b) { a.x += b; a.y += b; a.z += b; return a; } +static __device__ __forceinline__ int3& operator-= (int3& a, int b) { a.x -= b; a.y -= b; a.z -= b; return a; } +static __device__ __forceinline__ int3 operator* (const int3& a, const int3& b) { return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); } +static __device__ __forceinline__ int3 operator+ (const int3& a, const int3& b) { return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); } +static __device__ __forceinline__ int3 operator- (const int3& a, const int3& b) { return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); } +static __device__ __forceinline__ int3 operator* (const int3& a, int b) { return make_int3(a.x * b, a.y * b, a.z * b); } +static __device__ __forceinline__ int3 operator+ (const int3& a, int b) { return make_int3(a.x + b, a.y + b, a.z + b); } +static __device__ __forceinline__ int3 operator- (const int3& a, int b) { return make_int3(a.x - b, a.y - b, a.z - b); } +static __device__ __forceinline__ int3 operator* (int a, const int3& b) { return make_int3(a * b.x, a * b.y, a * b.z); } +static __device__ __forceinline__ int3 operator+ (int a, const int3& b) { return make_int3(a + b.x, a + b.y, a + b.z); } +static __device__ __forceinline__ int3 operator- (int a, const int3& b) { return make_int3(a - b.x, a - b.y, a - b.z); } +static __device__ __forceinline__ int3 operator- (const int3& a) { return make_int3(-a.x, -a.y, -a.z); } +static __device__ __forceinline__ int4& operator*= (int4& a, const int4& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; } +static __device__ __forceinline__ int4& operator+= (int4& a, const int4& b) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; } +static __device__ __forceinline__ int4& operator-= (int4& a, const int4& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; } +static __device__ __forceinline__ int4& operator*= (int4& a, int b) { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; } +static __device__ __forceinline__ int4& operator+= (int4& a, int b) { a.x += b; a.y += b; a.z += b; a.w += b; return a; } +static __device__ __forceinline__ int4& operator-= (int4& a, int b) { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; } +static __device__ __forceinline__ int4 operator* (const int4& a, const int4& b) { return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +static __device__ __forceinline__ int4 operator+ (const int4& a, const int4& b) { return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +static __device__ __forceinline__ int4 operator- (const int4& a, const int4& b) { return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +static __device__ __forceinline__ int4 operator* (const int4& a, int b) { return make_int4(a.x * b, a.y * b, a.z * b, a.w * b); } +static __device__ __forceinline__ int4 operator+ (const int4& a, int b) { return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); } +static __device__ __forceinline__ int4 operator- (const int4& a, int b) { return make_int4(a.x - b, a.y - b, a.z - b, a.w - b); } +static __device__ __forceinline__ int4 operator* (int a, const int4& b) { return make_int4(a * b.x, a * b.y, a * b.z, a * b.w); } +static __device__ __forceinline__ int4 operator+ (int a, const int4& b) { return make_int4(a + b.x, a + b.y, a + b.z, a + b.w); } +static __device__ __forceinline__ int4 operator- (int a, const int4& b) { return make_int4(a - b.x, a - b.y, a - b.z, a - b.w); } +static __device__ __forceinline__ int4 operator- (const int4& a) { return make_int4(-a.x, -a.y, -a.z, -a.w); } +static __device__ __forceinline__ uint2& operator*= (uint2& a, const uint2& b) { a.x *= b.x; a.y *= b.y; return a; } +static __device__ __forceinline__ uint2& operator+= (uint2& a, const uint2& b) { a.x += b.x; a.y += b.y; return a; } +static __device__ __forceinline__ uint2& operator-= (uint2& a, const uint2& b) { a.x -= b.x; a.y -= b.y; return a; } +static __device__ __forceinline__ uint2& operator*= (uint2& a, unsigned int b) { a.x *= b; a.y *= b; return a; } +static __device__ __forceinline__ uint2& operator+= (uint2& a, unsigned int b) { a.x += b; a.y += b; return a; } +static __device__ __forceinline__ uint2& operator-= (uint2& a, unsigned int b) { a.x -= b; a.y -= b; return a; } +static __device__ __forceinline__ uint2 operator* (const uint2& a, const uint2& b) { return make_uint2(a.x * b.x, a.y * b.y); } +static __device__ __forceinline__ uint2 operator+ (const uint2& a, const uint2& b) { return make_uint2(a.x + b.x, a.y + b.y); } +static __device__ __forceinline__ uint2 operator- (const uint2& a, const uint2& b) { return make_uint2(a.x - b.x, a.y - b.y); } +static __device__ __forceinline__ uint2 operator* (const uint2& a, unsigned int b) { return make_uint2(a.x * b, a.y * b); } +static __device__ __forceinline__ uint2 operator+ (const uint2& a, unsigned int b) { return make_uint2(a.x + b, a.y + b); } +static __device__ __forceinline__ uint2 operator- (const uint2& a, unsigned int b) { return make_uint2(a.x - b, a.y - b); } +static __device__ __forceinline__ uint2 operator* (unsigned int a, const uint2& b) { return make_uint2(a * b.x, a * b.y); } +static __device__ __forceinline__ uint2 operator+ (unsigned int a, const uint2& b) { return make_uint2(a + b.x, a + b.y); } +static __device__ __forceinline__ uint2 operator- (unsigned int a, const uint2& b) { return make_uint2(a - b.x, a - b.y); } +static __device__ __forceinline__ uint3& operator*= (uint3& a, const uint3& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; } +static __device__ __forceinline__ uint3& operator+= (uint3& a, const uint3& b) { a.x += b.x; a.y += b.y; a.z += b.z; return a; } +static __device__ __forceinline__ uint3& operator-= (uint3& a, const uint3& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; } +static __device__ __forceinline__ uint3& operator*= (uint3& a, unsigned int b) { a.x *= b; a.y *= b; a.z *= b; return a; } +static __device__ __forceinline__ uint3& operator+= (uint3& a, unsigned int b) { a.x += b; a.y += b; a.z += b; return a; } +static __device__ __forceinline__ uint3& operator-= (uint3& a, unsigned int b) { a.x -= b; a.y -= b; a.z -= b; return a; } +static __device__ __forceinline__ uint3 operator* (const uint3& a, const uint3& b) { return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); } +static __device__ __forceinline__ uint3 operator+ (const uint3& a, const uint3& b) { return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); } +static __device__ __forceinline__ uint3 operator- (const uint3& a, const uint3& b) { return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); } +static __device__ __forceinline__ uint3 operator* (const uint3& a, unsigned int b) { return make_uint3(a.x * b, a.y * b, a.z * b); } +static __device__ __forceinline__ uint3 operator+ (const uint3& a, unsigned int b) { return make_uint3(a.x + b, a.y + b, a.z + b); } +static __device__ __forceinline__ uint3 operator- (const uint3& a, unsigned int b) { return make_uint3(a.x - b, a.y - b, a.z - b); } +static __device__ __forceinline__ uint3 operator* (unsigned int a, const uint3& b) { return make_uint3(a * b.x, a * b.y, a * b.z); } +static __device__ __forceinline__ uint3 operator+ (unsigned int a, const uint3& b) { return make_uint3(a + b.x, a + b.y, a + b.z); } +static __device__ __forceinline__ uint3 operator- (unsigned int a, const uint3& b) { return make_uint3(a - b.x, a - b.y, a - b.z); } +static __device__ __forceinline__ uint4& operator*= (uint4& a, const uint4& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; } +static __device__ __forceinline__ uint4& operator+= (uint4& a, const uint4& b) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; } +static __device__ __forceinline__ uint4& operator-= (uint4& a, const uint4& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; } +static __device__ __forceinline__ uint4& operator*= (uint4& a, unsigned int b) { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; } +static __device__ __forceinline__ uint4& operator+= (uint4& a, unsigned int b) { a.x += b; a.y += b; a.z += b; a.w += b; return a; } +static __device__ __forceinline__ uint4& operator-= (uint4& a, unsigned int b) { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; } +static __device__ __forceinline__ uint4 operator* (const uint4& a, const uint4& b) { return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +static __device__ __forceinline__ uint4 operator+ (const uint4& a, const uint4& b) { return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +static __device__ __forceinline__ uint4 operator- (const uint4& a, const uint4& b) { return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +static __device__ __forceinline__ uint4 operator* (const uint4& a, unsigned int b) { return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b); } +static __device__ __forceinline__ uint4 operator+ (const uint4& a, unsigned int b) { return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); } +static __device__ __forceinline__ uint4 operator- (const uint4& a, unsigned int b) { return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b); } +static __device__ __forceinline__ uint4 operator* (unsigned int a, const uint4& b) { return make_uint4(a * b.x, a * b.y, a * b.z, a * b.w); } +static __device__ __forceinline__ uint4 operator+ (unsigned int a, const uint4& b) { return make_uint4(a + b.x, a + b.y, a + b.z, a + b.w); } +static __device__ __forceinline__ uint4 operator- (unsigned int a, const uint4& b) { return make_uint4(a - b.x, a - b.y, a - b.z, a - b.w); } + +template static __device__ __forceinline__ T zero_value(void); +template<> __device__ __forceinline__ float zero_value (void) { return 0.f; } +template<> __device__ __forceinline__ float2 zero_value(void) { return make_float2(0.f, 0.f); } +template<> __device__ __forceinline__ float4 zero_value(void) { return make_float4(0.f, 0.f, 0.f, 0.f); } +static __device__ __forceinline__ float3 make_float3(const float2& a, float b) { return make_float3(a.x, a.y, b); } +static __device__ __forceinline__ float4 make_float4(const float3& a, float b) { return make_float4(a.x, a.y, a.z, b); } +static __device__ __forceinline__ float4 make_float4(const float2& a, const float2& b) { return make_float4(a.x, a.y, b.x, b.y); } +static __device__ __forceinline__ int3 make_int3(const int2& a, int b) { return make_int3(a.x, a.y, b); } +static __device__ __forceinline__ int4 make_int4(const int3& a, int b) { return make_int4(a.x, a.y, a.z, b); } +static __device__ __forceinline__ int4 make_int4(const int2& a, const int2& b) { return make_int4(a.x, a.y, b.x, b.y); } +static __device__ __forceinline__ uint3 make_uint3(const uint2& a, unsigned int b) { return make_uint3(a.x, a.y, b); } +static __device__ __forceinline__ uint4 make_uint4(const uint3& a, unsigned int b) { return make_uint4(a.x, a.y, a.z, b); } +static __device__ __forceinline__ uint4 make_uint4(const uint2& a, const uint2& b) { return make_uint4(a.x, a.y, b.x, b.y); } + +template static __device__ __forceinline__ void swap(T& a, T& b) { T temp = a; a = b; b = temp; } + +//------------------------------------------------------------------------ +// Triangle ID <-> float32 conversion functions to support very large triangle IDs. +// +// Values up to and including 16777216 (also, negative values) are converted trivially and retain +// compatibility with previous versions. Larger values are mapped to unique float32 that are not equal to +// the ID. The largest value that converts to float32 and back without generating inf or nan is 889192447. + +static __device__ __forceinline__ int float_to_triidx(float x) { if (x <= 16777216.f) return (int)x; return __float_as_int(x) - 0x4a800000; } +static __device__ __forceinline__ float triidx_to_float(int x) { if (x <= 0x01000000) return (float)x; return __int_as_float(0x4a800000 + x); } + +//------------------------------------------------------------------------ +// Coalesced atomics. These are all done via macros. + +#if __CUDA_ARCH__ >= 700 // Warp match instruction __match_any_sync() is only available on compute capability 7.x and higher + +#define CA_TEMP _ca_temp +#define CA_TEMP_PARAM float* CA_TEMP +#define CA_DECLARE_TEMP(threads_per_block) \ + __shared__ float CA_TEMP[(threads_per_block)] + +#define CA_SET_GROUP_MASK(group, thread_mask) \ + bool _ca_leader; \ + float* _ca_ptr; \ + do { \ + int tidx = threadIdx.x + blockDim.x * threadIdx.y; \ + int lane = tidx & 31; \ + int warp = tidx >> 5; \ + int tmask = __match_any_sync((thread_mask), (group)); \ + int leader = __ffs(tmask) - 1; \ + _ca_leader = (leader == lane); \ + _ca_ptr = &_ca_temp[((warp << 5) + leader)]; \ + } while(0) + +#define CA_SET_GROUP(group) \ + CA_SET_GROUP_MASK((group), 0xffffffffu) + +#define caAtomicAdd(ptr, value) \ + do { \ + if (_ca_leader) \ + *_ca_ptr = 0.f; \ + atomicAdd(_ca_ptr, (value)); \ + if (_ca_leader) \ + atomicAdd((ptr), *_ca_ptr); \ + } while(0) + +#define caAtomicAdd3_xyw(ptr, x, y, w) \ + do { \ + caAtomicAdd((ptr), (x)); \ + caAtomicAdd((ptr)+1, (y)); \ + caAtomicAdd((ptr)+3, (w)); \ + } while(0) + +#define caAtomicAddTexture(ptr, level, idx, value) \ + do { \ + CA_SET_GROUP((idx) ^ ((level) << 27)); \ + caAtomicAdd((ptr)+(idx), (value)); \ + } while(0) + +//------------------------------------------------------------------------ +// Disable atomic coalescing for compute capability lower than 7.x + +#else // __CUDA_ARCH__ >= 700 +#define CA_TEMP _ca_temp +#define CA_TEMP_PARAM float CA_TEMP +#define CA_DECLARE_TEMP(threads_per_block) CA_TEMP_PARAM +#define CA_SET_GROUP_MASK(group, thread_mask) +#define CA_SET_GROUP(group) +#define caAtomicAdd(ptr, value) atomicAdd((ptr), (value)) +#define caAtomicAdd3_xyw(ptr, x, y, w) \ + do { \ + atomicAdd((ptr), (x)); \ + atomicAdd((ptr)+1, (y)); \ + atomicAdd((ptr)+3, (w)); \ + } while(0) +#define caAtomicAddTexture(ptr, level, idx, value) atomicAdd((ptr)+(idx), (value)) +#endif // __CUDA_ARCH__ >= 700 + +//------------------------------------------------------------------------ +#endif // __CUDACC__ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9261f995af645f50d25f9de5267e692576093cf1 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/CudaRaster.hpp @@ -0,0 +1,63 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// This is a slimmed-down and modernized version of the original +// CudaRaster codebase that accompanied the HPG 2011 paper +// "High-Performance Software Rasterization on GPUs" by Laine and Karras. +// Modifications have been made to accommodate post-Volta execution model +// with warp divergence. Support for shading, blending, quad rendering, +// and supersampling have been removed as unnecessary for nvdiffrast. +//------------------------------------------------------------------------ + +namespace CR +{ + +class RasterImpl; + +//------------------------------------------------------------------------ +// Interface class to isolate user from implementation details. +//------------------------------------------------------------------------ + +class CudaRaster +{ +public: + enum + { + RenderModeFlag_EnableBackfaceCulling = 1 << 0, // Enable backface culling. + RenderModeFlag_EnableDepthPeeling = 1 << 1, // Enable depth peeling. Must have a peel buffer set. + }; + +public: + CudaRaster (void); + ~CudaRaster (void); + + void setBufferSize (int width, int height, int numImages); // Width and height are internally rounded up to multiples of tile size (8x8) for buffer sizes. + void setViewport (int width, int height, int offsetX, int offsetY); // Tiled rendering viewport setup. + void setRenderModeFlags (unsigned int renderModeFlags); // Affects all subsequent calls to drawTriangles(). Defaults to zero. + void deferredClear (unsigned int clearColor); // Clears color and depth buffers during next call to drawTriangles(). + void setVertexBuffer (void* vertices, int numVertices); // GPU pointer managed by caller. Vertex positions in clip space as float4 (x, y, z, w). + void setIndexBuffer (void* indices, int numTriangles); // GPU pointer managed by caller. Triangle index+color quadruplets as uint4 (idx0, idx1, idx2, color). + bool drawTriangles (const int* ranges, bool peel, cudaStream_t stream); // Ranges (offsets and counts) as #triangles entries, not as bytes. If NULL, draw all triangles. Returns false in case of internal overflow. + void* getColorBuffer (void); // GPU pointer managed by CudaRaster. + void* getDepthBuffer (void); // GPU pointer managed by CudaRaster. + void swapDepthAndPeel (void); // Swap depth and peeling buffers. + +private: + CudaRaster (const CudaRaster&); // forbidden + CudaRaster& operator= (const CudaRaster&); // forbidden + +private: + RasterImpl* m_impl; // Opaque pointer to implementation. +}; + +//------------------------------------------------------------------------ +} // namespace CR + diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/BinRaster.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/BinRaster.inl new file mode 100644 index 0000000000000000000000000000000000000000..0cfb2c356727a6ee9d3b1063c28fd4b6817e093c --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/BinRaster.inl @@ -0,0 +1,423 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ + +__device__ __inline__ void binRasterImpl(const CRParams p) +{ + __shared__ volatile U32 s_broadcast [CR_BIN_WARPS + 16]; + __shared__ volatile S32 s_outOfs [CR_MAXBINS_SQR]; + __shared__ volatile S32 s_outTotal [CR_MAXBINS_SQR]; + __shared__ volatile S32 s_overIndex [CR_MAXBINS_SQR]; + __shared__ volatile S32 s_outMask [CR_BIN_WARPS][CR_MAXBINS_SQR + 1]; // +1 to avoid bank collisions + __shared__ volatile S32 s_outCount [CR_BIN_WARPS][CR_MAXBINS_SQR + 1]; // +1 to avoid bank collisions + __shared__ volatile S32 s_triBuf [CR_BIN_WARPS*32*4]; // triangle ring buffer + __shared__ volatile U32 s_batchPos; + __shared__ volatile U32 s_bufCount; + __shared__ volatile U32 s_overTotal; + __shared__ volatile U32 s_allocBase; + + const CRImageParams& ip = getImageParams(p, blockIdx.z); + CRAtomics& atomics = p.atomics[blockIdx.z]; + const U8* triSubtris = (const U8*)p.triSubtris + p.maxSubtris * blockIdx.z; + const CRTriangleHeader* triHeader = (const CRTriangleHeader*)p.triHeader + p.maxSubtris * blockIdx.z; + + S32* binFirstSeg = (S32*)p.binFirstSeg + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + S32* binTotal = (S32*)p.binTotal + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + S32* binSegData = (S32*)p.binSegData + p.maxBinSegs * CR_BIN_SEG_SIZE * blockIdx.z; + S32* binSegNext = (S32*)p.binSegNext + p.maxBinSegs * blockIdx.z; + S32* binSegCount = (S32*)p.binSegCount + p.maxBinSegs * blockIdx.z; + + if (atomics.numSubtris > p.maxSubtris) + return; + + // per-thread state + int thrInBlock = threadIdx.x + threadIdx.y * 32; + int batchPos = 0; + + // first 16 elements of s_broadcast are always zero + if (thrInBlock < 16) + s_broadcast[thrInBlock] = 0; + + // initialize output linked lists and offsets + if (thrInBlock < p.numBins) + { + binFirstSeg[(thrInBlock << CR_BIN_STREAMS_LOG2) + blockIdx.x] = -1; + s_outOfs[thrInBlock] = -CR_BIN_SEG_SIZE; + s_outTotal[thrInBlock] = 0; + } + + // repeat until done + for(;;) + { + // get batch + if (thrInBlock == 0) + s_batchPos = atomicAdd(&atomics.binCounter, ip.binBatchSize); + __syncthreads(); + batchPos = s_batchPos; + + // all batches done? + if (batchPos >= ip.triCount) + break; + + // per-thread state + int bufIndex = 0; + int bufCount = 0; + int batchEnd = min(batchPos + ip.binBatchSize, ip.triCount); + + // loop over batch as long as we have triangles in it + do + { + // read more triangles + while (bufCount < CR_BIN_WARPS*32 && batchPos < batchEnd) + { + // get subtriangle count + + int triIdx = batchPos + thrInBlock; + int num = 0; + if (triIdx < batchEnd) + num = triSubtris[triIdx]; + + // cumulative sum of subtriangles within each warp + U32 myIdx = __popc(__ballot_sync(~0u, num & 1) & getLaneMaskLt()); + if (__any_sync(~0u, num > 1)) + { + myIdx += __popc(__ballot_sync(~0u, num & 2) & getLaneMaskLt()) * 2; + myIdx += __popc(__ballot_sync(~0u, num & 4) & getLaneMaskLt()) * 4; + } + if (threadIdx.x == 31) // Do not assume that last thread in warp wins the write. + s_broadcast[threadIdx.y + 16] = myIdx + num; + __syncthreads(); + + // cumulative sum of per-warp subtriangle counts + // Note: cannot have more than 32 warps or this needs to sync between each step. + bool act = (thrInBlock < CR_BIN_WARPS); + U32 actMask = __ballot_sync(~0u, act); + if (threadIdx.y == 0 && act) + { + volatile U32* ptr = &s_broadcast[thrInBlock + 16]; + U32 val = *ptr; + #if (CR_BIN_WARPS > 1) + val += ptr[-1]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 2) + val += ptr[-2]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 4) + val += ptr[-4]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 8) + val += ptr[-8]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + #if (CR_BIN_WARPS > 16) + val += ptr[-16]; __syncwarp(actMask); + *ptr = val; __syncwarp(actMask); + #endif + + // initially assume that we consume everything + // only last active thread does the writes + if (threadIdx.x == CR_BIN_WARPS - 1) + { + s_batchPos = batchPos + CR_BIN_WARPS * 32; + s_bufCount = bufCount + val; + } + } + __syncthreads(); + + // skip if no subtriangles + if (num) + { + // calculate write position for first subtriangle + U32 pos = bufCount + myIdx + s_broadcast[threadIdx.y + 16 - 1]; + + // only write if entire triangle fits + if (pos + num <= CR_ARRAY_SIZE(s_triBuf)) + { + pos += bufIndex; // adjust for current start position + pos &= CR_ARRAY_SIZE(s_triBuf)-1; + if (num == 1) + s_triBuf[pos] = triIdx * 8 + 7; // single triangle + else + { + for (int i=0; i < num; i++) + { + s_triBuf[pos] = triIdx * 8 + i; + pos++; + pos &= CR_ARRAY_SIZE(s_triBuf)-1; + } + } + } else if (pos <= CR_ARRAY_SIZE(s_triBuf)) + { + // this triangle is the first that failed, overwrite total count and triangle count + s_batchPos = batchPos + thrInBlock; + s_bufCount = pos; + } + } + + // update triangle counts + __syncthreads(); + batchPos = s_batchPos; + bufCount = s_bufCount; + } + + // make every warp clear its output buffers + for (int i=threadIdx.x; i < p.numBins; i += 32) + s_outMask[threadIdx.y][i] = 0; + __syncwarp(); + + // choose our triangle + uint4 triData = make_uint4(0, 0, 0, 0); + if (thrInBlock < bufCount) + { + U32 triPos = bufIndex + thrInBlock; + triPos &= CR_ARRAY_SIZE(s_triBuf)-1; + + // find triangle + int triIdx = s_triBuf[triPos]; + int dataIdx = triIdx >> 3; + int subtriIdx = triIdx & 7; + if (subtriIdx != 7) + dataIdx = triHeader[dataIdx].misc + subtriIdx; + + // read triangle + + triData = *(((const uint4*)triHeader) + dataIdx); + } + + // setup bounding box and edge functions, and rasterize + S32 lox, loy, hix, hiy; + bool hasTri = (thrInBlock < bufCount); + U32 hasTriMask = __ballot_sync(~0u, hasTri); + if (hasTri) + { + S32 v0x = add_s16lo_s16lo(triData.x, p.widthPixelsVp * (CR_SUBPIXEL_SIZE >> 1)); + S32 v0y = add_s16hi_s16lo(triData.x, p.heightPixelsVp * (CR_SUBPIXEL_SIZE >> 1)); + S32 d01x = sub_s16lo_s16lo(triData.y, triData.x); + S32 d01y = sub_s16hi_s16hi(triData.y, triData.x); + S32 d02x = sub_s16lo_s16lo(triData.z, triData.x); + S32 d02y = sub_s16hi_s16hi(triData.z, triData.x); + int binLog = CR_BIN_LOG2 + CR_TILE_LOG2 + CR_SUBPIXEL_LOG2; + lox = add_clamp_0_x((v0x + min_min(d01x, 0, d02x)) >> binLog, 0, p.widthBins - 1); + loy = add_clamp_0_x((v0y + min_min(d01y, 0, d02y)) >> binLog, 0, p.heightBins - 1); + hix = add_clamp_0_x((v0x + max_max(d01x, 0, d02x)) >> binLog, 0, p.widthBins - 1); + hiy = add_clamp_0_x((v0y + max_max(d01y, 0, d02y)) >> binLog, 0, p.heightBins - 1); + + U32 bit = 1 << threadIdx.x; +#if __CUDA_ARCH__ >= 700 + bool multi = (hix != lox || hiy != loy); + if (!__any_sync(hasTriMask, multi)) + { + int binIdx = lox + p.widthBins * loy; + U32 mask = __match_any_sync(hasTriMask, binIdx); + s_outMask[threadIdx.y][binIdx] = mask; + __syncwarp(hasTriMask); + } else +#endif + { + bool complex = (hix > lox+1 || hiy > loy+1); + if (!__any_sync(hasTriMask, complex)) + { + int binIdx = lox + p.widthBins * loy; + atomicOr((U32*)&s_outMask[threadIdx.y][binIdx], bit); + if (hix > lox) atomicOr((U32*)&s_outMask[threadIdx.y][binIdx + 1], bit); + if (hiy > loy) atomicOr((U32*)&s_outMask[threadIdx.y][binIdx + p.widthBins], bit); + if (hix > lox && hiy > loy) atomicOr((U32*)&s_outMask[threadIdx.y][binIdx + p.widthBins + 1], bit); + } else + { + S32 d12x = d02x - d01x, d12y = d02y - d01y; + v0x -= lox << binLog, v0y -= loy << binLog; + + S32 t01 = v0x * d01y - v0y * d01x; + S32 t02 = v0y * d02x - v0x * d02y; + S32 t12 = d01x * d12y - d01y * d12x - t01 - t02; + S32 b01 = add_sub(t01 >> binLog, max(d01x, 0), min(d01y, 0)); + S32 b02 = add_sub(t02 >> binLog, max(d02y, 0), min(d02x, 0)); + S32 b12 = add_sub(t12 >> binLog, max(d12x, 0), min(d12y, 0)); + + int width = hix - lox + 1; + d01x += width * d01y; + d02x += width * d02y; + d12x += width * d12y; + + U8* currPtr = (U8*)&s_outMask[threadIdx.y][lox + loy * p.widthBins]; + U8* skipPtr = (U8*)&s_outMask[threadIdx.y][(hix + 1) + loy * p.widthBins]; + U8* endPtr = (U8*)&s_outMask[threadIdx.y][lox + (hiy + 1) * p.widthBins]; + int stride = p.widthBins * 4; + int ptrYInc = stride - width * 4; + + do + { + if (b01 >= 0 && b02 >= 0 && b12 >= 0) + atomicOr((U32*)currPtr, bit); + currPtr += 4, b01 -= d01y, b02 += d02y, b12 -= d12y; + if (currPtr == skipPtr) + currPtr += ptrYInc, b01 += d01x, b02 -= d02x, b12 += d12x, skipPtr += stride; + } + while (currPtr != endPtr); + } + } + } + + // count per-bin contributions + if (thrInBlock == 0) + s_overTotal = 0; // overflow counter + + // ensure that out masks are done + __syncthreads(); + + int overIndex = -1; + bool act = (thrInBlock < p.numBins); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + U8* srcPtr = (U8*)&s_outMask[0][thrInBlock]; + U8* dstPtr = (U8*)&s_outCount[0][thrInBlock]; + int total = 0; + for (int i = 0; i < CR_BIN_WARPS; i++) + { + total += __popc(*(U32*)srcPtr); + *(U32*)dstPtr = total; + srcPtr += (CR_MAXBINS_SQR + 1) * 4; + dstPtr += (CR_MAXBINS_SQR + 1) * 4; + } + + // overflow => request a new segment + int ofs = s_outOfs[thrInBlock]; + bool ovr = (((ofs - 1) >> CR_BIN_SEG_LOG2) != (((ofs - 1) + total) >> CR_BIN_SEG_LOG2)); + U32 ovrMask = __ballot_sync(actMask, ovr); + if (ovr) + { + overIndex = __popc(ovrMask & getLaneMaskLt()); + if (overIndex == 0) + s_broadcast[threadIdx.y + 16] = atomicAdd((U32*)&s_overTotal, __popc(ovrMask)); + __syncwarp(ovrMask); + overIndex += s_broadcast[threadIdx.y + 16]; + s_overIndex[thrInBlock] = overIndex; + } + } + + // sync after overTotal is ready + __syncthreads(); + + // at least one segment overflowed => allocate segments + U32 overTotal = s_overTotal; + U32 allocBase = 0; + if (overTotal > 0) + { + // allocate memory + if (thrInBlock == 0) + { + U32 allocBase = atomicAdd(&atomics.numBinSegs, overTotal); + s_allocBase = (allocBase + overTotal <= p.maxBinSegs) ? allocBase : 0; + } + __syncthreads(); + allocBase = s_allocBase; + + // did my bin overflow? + if (overIndex != -1) + { + // calculate new segment index + int segIdx = allocBase + overIndex; + + // add to linked list + if (s_outOfs[thrInBlock] < 0) + binFirstSeg[(thrInBlock << CR_BIN_STREAMS_LOG2) + blockIdx.x] = segIdx; + else + binSegNext[(s_outOfs[thrInBlock] - 1) >> CR_BIN_SEG_LOG2] = segIdx; + + // defaults + binSegNext [segIdx] = -1; + binSegCount[segIdx] = CR_BIN_SEG_SIZE; + } + } + + // concurrent emission -- each warp handles its own triangle + if (thrInBlock < bufCount) + { + int triPos = (bufIndex + thrInBlock) & (CR_ARRAY_SIZE(s_triBuf) - 1); + int currBin = lox + loy * p.widthBins; + int skipBin = (hix + 1) + loy * p.widthBins; + int endBin = lox + (hiy + 1) * p.widthBins; + int binYInc = p.widthBins - (hix - lox + 1); + + // loop over triangle's bins + do + { + U32 outMask = s_outMask[threadIdx.y][currBin]; + if (outMask & (1< 0) + idx += s_outCount[threadIdx.y-1][currBin]; + + int base = s_outOfs[currBin]; + int free = (-base) & (CR_BIN_SEG_SIZE - 1); + if (idx >= free) + idx += ((allocBase + s_overIndex[currBin]) << CR_BIN_SEG_LOG2) - free; + else + idx += base; + + binSegData[idx] = s_triBuf[triPos]; + } + + currBin++; + if (currBin == skipBin) + currBin += binYInc, skipBin += p.widthBins; + } + while (currBin != endBin); + } + + // wait all triangles to finish, then replace overflown segment offsets + __syncthreads(); + if (thrInBlock < p.numBins) + { + U32 total = s_outCount[CR_BIN_WARPS - 1][thrInBlock]; + U32 oldOfs = s_outOfs[thrInBlock]; + if (overIndex == -1) + s_outOfs[thrInBlock] = oldOfs + total; + else + { + int addr = oldOfs + total; + addr = ((addr - 1) & (CR_BIN_SEG_SIZE - 1)) + 1; + addr += (allocBase + overIndex) << CR_BIN_SEG_LOG2; + s_outOfs[thrInBlock] = addr; + } + s_outTotal[thrInBlock] += total; + } + + // these triangles are now done + int count = ::min(bufCount, CR_BIN_WARPS * 32); + bufCount -= count; + bufIndex += count; + bufIndex &= CR_ARRAY_SIZE(s_triBuf)-1; + } + while (bufCount > 0 || batchPos < batchEnd); + + // flush all bins + if (thrInBlock < p.numBins) + { + int ofs = s_outOfs[thrInBlock]; + if (ofs & (CR_BIN_SEG_SIZE-1)) + { + int seg = ofs >> CR_BIN_SEG_LOG2; + binSegCount[seg] = ofs & (CR_BIN_SEG_SIZE-1); + s_outOfs[thrInBlock] = (ofs + CR_BIN_SEG_SIZE - 1) & -CR_BIN_SEG_SIZE; + } + } + } + + // output totals + if (thrInBlock < p.numBins) + binTotal[(thrInBlock << CR_BIN_STREAMS_LOG2) + blockIdx.x] = s_outTotal[thrInBlock]; +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.cpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..acd42262ec68e324657bb7cbd7449150c136c6fc --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "../../framework.h" +#include "Buffer.hpp" + +using namespace CR; + +//------------------------------------------------------------------------ +// GPU buffer. +//------------------------------------------------------------------------ + +Buffer::Buffer(void) +: m_gpuPtr(NULL), + m_bytes (0) +{ + // empty +} + +Buffer::~Buffer(void) +{ + if (m_gpuPtr) + cudaFree(m_gpuPtr); // Don't throw an exception. +} + +void Buffer::reset(size_t bytes) +{ + if (bytes == m_bytes) + return; + + if (m_gpuPtr) + { + NVDR_CHECK_CUDA_ERROR(cudaFree(m_gpuPtr)); + m_gpuPtr = NULL; + } + + if (bytes > 0) + NVDR_CHECK_CUDA_ERROR(cudaMalloc(&m_gpuPtr, bytes)); + + m_bytes = bytes; +} + +void Buffer::grow(size_t bytes) +{ + if (bytes > m_bytes) + reset(bytes); +} + +//------------------------------------------------------------------------ +// Host buffer with page-locked memory. +//------------------------------------------------------------------------ + +HostBuffer::HostBuffer(void) +: m_hostPtr(NULL), + m_bytes (0) +{ + // empty +} + +HostBuffer::~HostBuffer(void) +{ + if (m_hostPtr) + cudaFreeHost(m_hostPtr); // Don't throw an exception. +} + +void HostBuffer::reset(size_t bytes) +{ + if (bytes == m_bytes) + return; + + if (m_hostPtr) + { + NVDR_CHECK_CUDA_ERROR(cudaFreeHost(m_hostPtr)); + m_hostPtr = NULL; + } + + if (bytes > 0) + NVDR_CHECK_CUDA_ERROR(cudaMallocHost(&m_hostPtr, bytes)); + + m_bytes = bytes; +} + +void HostBuffer::grow(size_t bytes) +{ + if (bytes > m_bytes) + reset(bytes); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1063a68d0c024f8580a99efefeeb0ed56ec3f239 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Buffer.hpp @@ -0,0 +1,55 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "Defs.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ + +class Buffer +{ +public: + Buffer (void); + ~Buffer (void); + + void reset (size_t bytes); + void grow (size_t bytes); + void* getPtr (size_t offset = 0) { return (void*)(((uintptr_t)m_gpuPtr) + offset); } + size_t getSize (void) const { return m_bytes; } + + void setPtr (void* ptr) { m_gpuPtr = ptr; } + +private: + void* m_gpuPtr; + size_t m_bytes; +}; + +//------------------------------------------------------------------------ + +class HostBuffer +{ +public: + HostBuffer (void); + ~HostBuffer (void); + + void reset (size_t bytes); + void grow (size_t bytes); + void* getPtr (void) { return m_hostPtr; } + size_t getSize (void) const { return m_bytes; } + + void setPtr (void* ptr) { m_hostPtr = ptr; } + +private: + void* m_hostPtr; + size_t m_bytes; +}; + +//------------------------------------------------------------------------ +} diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CoarseRaster.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CoarseRaster.inl new file mode 100644 index 0000000000000000000000000000000000000000..dc39da88d341da0d9d704f9cc3cdff1c793cca54 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CoarseRaster.inl @@ -0,0 +1,730 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ + +__device__ __inline__ int globalTileIdx(int tileInBin, int widthTiles) +{ + int tileX = tileInBin & (CR_BIN_SIZE - 1); + int tileY = tileInBin >> CR_BIN_LOG2; + return tileX + tileY * widthTiles; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void coarseRasterImpl(const CRParams p) +{ + // Common. + + __shared__ volatile U32 s_workCounter; + __shared__ volatile U32 s_scanTemp [CR_COARSE_WARPS][48]; // 3KB + + // Input. + + __shared__ volatile U32 s_binOrder [CR_MAXBINS_SQR]; // 1KB + __shared__ volatile S32 s_binStreamCurrSeg [CR_BIN_STREAMS_SIZE]; // 0KB + __shared__ volatile S32 s_binStreamFirstTri [CR_BIN_STREAMS_SIZE]; // 0KB + __shared__ volatile S32 s_triQueue [CR_COARSE_QUEUE_SIZE]; // 4KB + __shared__ volatile S32 s_triQueueWritePos; + __shared__ volatile U32 s_binStreamSelectedOfs; + __shared__ volatile U32 s_binStreamSelectedSize; + + // Output. + + __shared__ volatile U32 s_warpEmitMask [CR_COARSE_WARPS][CR_BIN_SQR + 1]; // 16KB, +1 to avoid bank collisions + __shared__ volatile U32 s_warpEmitPrefixSum [CR_COARSE_WARPS][CR_BIN_SQR + 1]; // 16KB, +1 to avoid bank collisions + __shared__ volatile U32 s_tileEmitPrefixSum [CR_BIN_SQR + 1]; // 1KB, zero at the beginning + __shared__ volatile U32 s_tileAllocPrefixSum[CR_BIN_SQR + 1]; // 1KB, zero at the beginning + __shared__ volatile S32 s_tileStreamCurrOfs [CR_BIN_SQR]; // 1KB + __shared__ volatile U32 s_firstAllocSeg; + __shared__ volatile U32 s_firstActiveIdx; + + // Pointers and constants. + + CRAtomics& atomics = p.atomics[blockIdx.z]; + const CRTriangleHeader* triHeader = (const CRTriangleHeader*)p.triHeader + p.maxSubtris * blockIdx.z; + const S32* binFirstSeg = (const S32*)p.binFirstSeg + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + const S32* binTotal = (const S32*)p.binTotal + CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * blockIdx.z; + const S32* binSegData = (const S32*)p.binSegData + p.maxBinSegs * CR_BIN_SEG_SIZE * blockIdx.z; + const S32* binSegNext = (const S32*)p.binSegNext + p.maxBinSegs * blockIdx.z; + const S32* binSegCount = (const S32*)p.binSegCount + p.maxBinSegs * blockIdx.z; + S32* activeTiles = (S32*)p.activeTiles + CR_MAXTILES_SQR * blockIdx.z; + S32* tileFirstSeg = (S32*)p.tileFirstSeg + CR_MAXTILES_SQR * blockIdx.z; + S32* tileSegData = (S32*)p.tileSegData + p.maxTileSegs * CR_TILE_SEG_SIZE * blockIdx.z; + S32* tileSegNext = (S32*)p.tileSegNext + p.maxTileSegs * blockIdx.z; + S32* tileSegCount = (S32*)p.tileSegCount + p.maxTileSegs * blockIdx.z; + + int tileLog = CR_TILE_LOG2 + CR_SUBPIXEL_LOG2; + int thrInBlock = threadIdx.x + threadIdx.y * 32; + int emitShift = CR_BIN_LOG2 * 2 + 5; // We scan ((numEmits << emitShift) | numAllocs) over tiles. + + if (atomics.numSubtris > p.maxSubtris || atomics.numBinSegs > p.maxBinSegs) + return; + + // Initialize sharedmem arrays. + + if (thrInBlock == 0) + { + s_tileEmitPrefixSum[0] = 0; + s_tileAllocPrefixSum[0] = 0; + } + s_scanTemp[threadIdx.y][threadIdx.x] = 0; + + // Sort bins in descending order of triangle count. + + for (int binIdx = thrInBlock; binIdx < p.numBins; binIdx += CR_COARSE_WARPS * 32) + { + int count = 0; + for (int i = 0; i < CR_BIN_STREAMS_SIZE; i++) + count += binTotal[(binIdx << CR_BIN_STREAMS_LOG2) + i]; + s_binOrder[binIdx] = (~count << (CR_MAXBINS_LOG2 * 2)) | binIdx; + } + + __syncthreads(); + sortShared(s_binOrder, p.numBins); + + // Process each bin by one block. + + for (;;) + { + // Pick a bin for the block. + + if (thrInBlock == 0) + s_workCounter = atomicAdd(&atomics.coarseCounter, 1); + __syncthreads(); + + int workCounter = s_workCounter; + if (workCounter >= p.numBins) + break; + + U32 binOrder = s_binOrder[workCounter]; + bool binEmpty = ((~binOrder >> (CR_MAXBINS_LOG2 * 2)) == 0); + if (binEmpty && !p.deferredClear) + break; + + int binIdx = binOrder & (CR_MAXBINS_SQR - 1); + + // Initialize input/output streams. + + int triQueueWritePos = 0; + int triQueueReadPos = 0; + + if (thrInBlock < CR_BIN_STREAMS_SIZE) + { + int segIdx = binFirstSeg[(binIdx << CR_BIN_STREAMS_LOG2) + thrInBlock]; + s_binStreamCurrSeg[thrInBlock] = segIdx; + s_binStreamFirstTri[thrInBlock] = (segIdx == -1) ? ~0u : binSegData[segIdx << CR_BIN_SEG_LOG2]; + } + + for (int tileInBin = CR_COARSE_WARPS * 32 - 1 - thrInBlock; tileInBin < CR_BIN_SQR; tileInBin += CR_COARSE_WARPS * 32) + s_tileStreamCurrOfs[tileInBin] = -CR_TILE_SEG_SIZE; + + // Initialize per-bin state. + + int binY = idiv_fast(binIdx, p.widthBins); + int binX = binIdx - binY * p.widthBins; + int originX = (binX << (CR_BIN_LOG2 + tileLog)) - (p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + int originY = (binY << (CR_BIN_LOG2 + tileLog)) - (p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + int maxTileXInBin = ::min(p.widthTiles - (binX << CR_BIN_LOG2), CR_BIN_SIZE) - 1; + int maxTileYInBin = ::min(p.heightTiles - (binY << CR_BIN_LOG2), CR_BIN_SIZE) - 1; + int binTileIdx = (binX + binY * p.widthTiles) << CR_BIN_LOG2; + + // Entire block: Merge input streams and process triangles. + + if (!binEmpty) + do + { + //------------------------------------------------------------------------ + // Merge. + //------------------------------------------------------------------------ + + // Entire block: Not enough triangles => merge and queue segments. + // NOTE: The bin exit criterion assumes that we queue more triangles than we actually need. + + while (triQueueWritePos - triQueueReadPos <= CR_COARSE_WARPS * 32) + { + // First warp: Choose the segment with the lowest initial triangle index. + + bool hasStream = (thrInBlock < CR_BIN_STREAMS_SIZE); + U32 hasStreamMask = __ballot_sync(~0u, hasStream); + if (hasStream) + { + // Find the stream with the lowest triangle index. + + U32 firstTri = s_binStreamFirstTri[thrInBlock]; + U32 t = firstTri; + volatile U32* v = &s_scanTemp[0][thrInBlock + 16]; + + #if (CR_BIN_STREAMS_SIZE > 1) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-1]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 2) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-2]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 4) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-4]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 8) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-8]); __syncwarp(hasStreamMask); + #endif + #if (CR_BIN_STREAMS_SIZE > 16) + v[0] = t; __syncwarp(hasStreamMask); t = ::min(t, v[-16]); __syncwarp(hasStreamMask); + #endif + v[0] = t; __syncwarp(hasStreamMask); + + // Consume and broadcast. + + bool first = (s_scanTemp[0][CR_BIN_STREAMS_SIZE - 1 + 16] == firstTri); + U32 firstMask = __ballot_sync(hasStreamMask, first); + if (first && (firstMask >> threadIdx.x) == 1u) + { + int segIdx = s_binStreamCurrSeg[thrInBlock]; + s_binStreamSelectedOfs = segIdx << CR_BIN_SEG_LOG2; + if (segIdx != -1) + { + int segSize = binSegCount[segIdx]; + int segNext = binSegNext[segIdx]; + s_binStreamSelectedSize = segSize; + s_triQueueWritePos = triQueueWritePos + segSize; + s_binStreamCurrSeg[thrInBlock] = segNext; + s_binStreamFirstTri[thrInBlock] = (segNext == -1) ? ~0u : binSegData[segNext << CR_BIN_SEG_LOG2]; + } + } + } + + // No more segments => break. + + __syncthreads(); + triQueueWritePos = s_triQueueWritePos; + int segOfs = s_binStreamSelectedOfs; + if (segOfs < 0) + break; + + int segSize = s_binStreamSelectedSize; + __syncthreads(); + + // Fetch triangles into the queue. + + for (int idxInSeg = CR_COARSE_WARPS * 32 - 1 - thrInBlock; idxInSeg < segSize; idxInSeg += CR_COARSE_WARPS * 32) + { + S32 triIdx = binSegData[segOfs + idxInSeg]; + s_triQueue[(triQueueWritePos - segSize + idxInSeg) & (CR_COARSE_QUEUE_SIZE - 1)] = triIdx; + } + } + + // All threads: Clear emit masks. + + for (int maskIdx = thrInBlock; maskIdx < CR_COARSE_WARPS * CR_BIN_SQR; maskIdx += CR_COARSE_WARPS * 32) + s_warpEmitMask[maskIdx >> (CR_BIN_LOG2 * 2)][maskIdx & (CR_BIN_SQR - 1)] = 0; + + __syncthreads(); + + //------------------------------------------------------------------------ + // Raster. + //------------------------------------------------------------------------ + + // Triangle per thread: Read from the queue. + + int triIdx = -1; + if (triQueueReadPos + thrInBlock < triQueueWritePos) + triIdx = s_triQueue[(triQueueReadPos + thrInBlock) & (CR_COARSE_QUEUE_SIZE - 1)]; + + uint4 triData = make_uint4(0, 0, 0, 0); + if (triIdx != -1) + { + int dataIdx = triIdx >> 3; + int subtriIdx = triIdx & 7; + if (subtriIdx != 7) + dataIdx = triHeader[dataIdx].misc + subtriIdx; + triData = *((uint4*)triHeader + dataIdx); + } + + // 32 triangles per warp: Record emits (= tile intersections). + + if (__any_sync(~0u, triIdx != -1)) + { + S32 v0x = sub_s16lo_s16lo(triData.x, originX); + S32 v0y = sub_s16hi_s16lo(triData.x, originY); + S32 d01x = sub_s16lo_s16lo(triData.y, triData.x); + S32 d01y = sub_s16hi_s16hi(triData.y, triData.x); + S32 d02x = sub_s16lo_s16lo(triData.z, triData.x); + S32 d02y = sub_s16hi_s16hi(triData.z, triData.x); + + // Compute tile-based AABB. + + int lox = add_clamp_0_x((v0x + min_min(d01x, 0, d02x)) >> tileLog, 0, maxTileXInBin); + int loy = add_clamp_0_x((v0y + min_min(d01y, 0, d02y)) >> tileLog, 0, maxTileYInBin); + int hix = add_clamp_0_x((v0x + max_max(d01x, 0, d02x)) >> tileLog, 0, maxTileXInBin); + int hiy = add_clamp_0_x((v0y + max_max(d01y, 0, d02y)) >> tileLog, 0, maxTileYInBin); + int sizex = add_sub(hix, 1, lox); + int sizey = add_sub(hiy, 1, loy); + int area = sizex * sizey; + + // Miscellaneous init. + + U8* currPtr = (U8*)&s_warpEmitMask[threadIdx.y][lox + (loy << CR_BIN_LOG2)]; + int ptrYInc = CR_BIN_SIZE * 4 - (sizex << 2); + U32 maskBit = 1 << threadIdx.x; + + // Case A: All AABBs are small => record the full AABB using atomics. + + if (__all_sync(~0u, sizex <= 2 && sizey <= 2)) + { + if (triIdx != -1) + { + atomicOr((U32*)currPtr, maskBit); + if (sizex == 2) atomicOr((U32*)(currPtr + 4), maskBit); + if (sizey == 2) atomicOr((U32*)(currPtr + CR_BIN_SIZE * 4), maskBit); + if (sizex == 2 && sizey == 2) atomicOr((U32*)(currPtr + 4 + CR_BIN_SIZE * 4), maskBit); + } + } + else + { + // Compute warp-AABB (scan-32). + + U32 aabbMask = add_sub(2 << hix, 0x20000 << hiy, 1 << lox) - (0x10000 << loy); + if (triIdx == -1) + aabbMask = 0; + + volatile U32* v = &s_scanTemp[threadIdx.y][threadIdx.x + 16]; + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-1]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-2]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-4]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-8]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask |= v[-16]; __syncwarp(); + v[0] = aabbMask; __syncwarp(); aabbMask = s_scanTemp[threadIdx.y][47]; + + U32 maskX = aabbMask & 0xFFFF; + U32 maskY = aabbMask >> 16; + int wlox = findLeadingOne(maskX ^ (maskX - 1)); + int wloy = findLeadingOne(maskY ^ (maskY - 1)); + int whix = findLeadingOne(maskX); + int whiy = findLeadingOne(maskY); + int warea = (add_sub(whix, 1, wlox)) * (add_sub(whiy, 1, wloy)); + + // Initialize edge functions. + + S32 d12x = d02x - d01x; + S32 d12y = d02y - d01y; + v0x -= lox << tileLog; + v0y -= loy << tileLog; + + S32 t01 = v0x * d01y - v0y * d01x; + S32 t02 = v0y * d02x - v0x * d02y; + S32 t12 = d01x * d12y - d01y * d12x - t01 - t02; + S32 b01 = add_sub(t01 >> tileLog, ::max(d01x, 0), ::min(d01y, 0)); + S32 b02 = add_sub(t02 >> tileLog, ::max(d02y, 0), ::min(d02x, 0)); + S32 b12 = add_sub(t12 >> tileLog, ::max(d12x, 0), ::min(d12y, 0)); + + d01x += sizex * d01y; + d02x += sizex * d02y; + d12x += sizex * d12y; + + // Case B: Warp-AABB is not much larger than largest AABB => Check tiles in warp-AABB, record using ballots. + if (__any_sync(~0u, warea * 4 <= area * 8)) + { + // Not sure if this is any faster than Case C after all the post-Volta ballot mask tracking. + bool act = (triIdx != -1); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + for (int y = wloy; y <= whiy; y++) + { + bool yIn = (y >= loy && y <= hiy); + U32 yMask = __ballot_sync(actMask, yIn); + if (yIn) + { + for (int x = wlox; x <= whix; x++) + { + bool xyIn = (x >= lox && x <= hix); + U32 xyMask = __ballot_sync(yMask, xyIn); + if (xyIn) + { + U32 res = __ballot_sync(xyMask, b01 >= 0 && b02 >= 0 && b12 >= 0); + if (threadIdx.x == 31 - __clz(xyMask)) + *(U32*)currPtr = res; + currPtr += 4, b01 -= d01y, b02 += d02y, b12 -= d12y; + } + } + currPtr += ptrYInc, b01 += d01x, b02 -= d02x, b12 += d12x; + } + } + } + } + + // Case C: General case => Check tiles in AABB, record using atomics. + + else + { + if (triIdx != -1) + { + U8* skipPtr = currPtr + (sizex << 2); + U8* endPtr = currPtr + (sizey << (CR_BIN_LOG2 + 2)); + do + { + if (b01 >= 0 && b02 >= 0 && b12 >= 0) + atomicOr((U32*)currPtr, maskBit); + currPtr += 4, b01 -= d01y, b02 += d02y, b12 -= d12y; + if (currPtr == skipPtr) + currPtr += ptrYInc, b01 += d01x, b02 -= d02x, b12 += d12x, skipPtr += CR_BIN_SIZE * 4; + } + while (currPtr != endPtr); + } + } + } + } + + __syncthreads(); + + //------------------------------------------------------------------------ + // Count. + //------------------------------------------------------------------------ + + // Tile per thread: Initialize prefix sums. + + for (int tileInBin_base = 0; tileInBin_base < CR_BIN_SQR; tileInBin_base += CR_COARSE_WARPS * 32) + { + int tileInBin = tileInBin_base + thrInBlock; + bool act = (tileInBin < CR_BIN_SQR); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + // Compute prefix sum of emits over warps. + + U8* srcPtr = (U8*)&s_warpEmitMask[0][tileInBin]; + U8* dstPtr = (U8*)&s_warpEmitPrefixSum[0][tileInBin]; + int tileEmits = 0; + for (int i = 0; i < CR_COARSE_WARPS; i++) + { + tileEmits += __popc(*(U32*)srcPtr); + *(U32*)dstPtr = tileEmits; + srcPtr += (CR_BIN_SQR + 1) * 4; + dstPtr += (CR_BIN_SQR + 1) * 4; + } + + // Determine the number of segments to allocate. + + int spaceLeft = -s_tileStreamCurrOfs[tileInBin] & (CR_TILE_SEG_SIZE - 1); + int tileAllocs = (tileEmits - spaceLeft + CR_TILE_SEG_SIZE - 1) >> CR_TILE_SEG_LOG2; + volatile U32* v = &s_tileEmitPrefixSum[tileInBin + 1]; + + // All counters within the warp are small => compute prefix sum using ballot. + + if (!__any_sync(actMask, tileEmits >= 2)) + { + U32 m = getLaneMaskLe(); + *v = (__popc(__ballot_sync(actMask, tileEmits & 1) & m) << emitShift) | __popc(__ballot_sync(actMask, tileAllocs & 1) & m); + } + + // Otherwise => scan-32 within the warp. + + else + { + U32 sum = (tileEmits << emitShift) | tileAllocs; + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 1) sum += v[-1]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 2) sum += v[-2]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 4) sum += v[-4]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 8) sum += v[-8]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); if (threadIdx.x >= 16) sum += v[-16]; __syncwarp(actMask); + *v = sum; __syncwarp(actMask); + } + } + } + + // First warp: Scan-8. + + __syncthreads(); + + bool scan8 = (thrInBlock < CR_BIN_SQR / 32); + U32 scan8Mask = __ballot_sync(~0u, scan8); + if (scan8) + { + int sum = s_tileEmitPrefixSum[(thrInBlock << 5) + 32]; + volatile U32* v = &s_scanTemp[0][thrInBlock + 16]; + v[0] = sum; __syncwarp(scan8Mask); + #if (CR_BIN_SQR > 1 * 32) + sum += v[-1]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 2 * 32) + sum += v[-2]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 4 * 32) + sum += v[-4]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + } + + __syncthreads(); + + // Tile per thread: Finalize prefix sums. + // Single thread: Allocate segments. + + for (int tileInBin = thrInBlock; tileInBin < CR_BIN_SQR; tileInBin += CR_COARSE_WARPS * 32) + { + int sum = s_tileEmitPrefixSum[tileInBin + 1] + s_scanTemp[0][(tileInBin >> 5) + 15]; + int numEmits = sum >> emitShift; + int numAllocs = sum & ((1 << emitShift) - 1); + s_tileEmitPrefixSum[tileInBin + 1] = numEmits; + s_tileAllocPrefixSum[tileInBin + 1] = numAllocs; + + if (tileInBin == CR_BIN_SQR - 1 && numAllocs != 0) + { + int t = atomicAdd(&atomics.numTileSegs, numAllocs); + s_firstAllocSeg = (t + numAllocs <= p.maxTileSegs) ? t : 0; + } + } + + __syncthreads(); + int firstAllocSeg = s_firstAllocSeg; + int totalEmits = s_tileEmitPrefixSum[CR_BIN_SQR]; + int totalAllocs = s_tileAllocPrefixSum[CR_BIN_SQR]; + + //------------------------------------------------------------------------ + // Emit. + //------------------------------------------------------------------------ + + // Emit per thread: Write triangle index to globalmem. + + for (int emitInBin = thrInBlock; emitInBin < totalEmits; emitInBin += CR_COARSE_WARPS * 32) + { + // Find tile in bin. + + U8* tileBase = (U8*)&s_tileEmitPrefixSum[0]; + U8* tilePtr = tileBase; + U8* ptr; + + #if (CR_BIN_SQR > 128) + ptr = tilePtr + 0x80 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 64) + ptr = tilePtr + 0x40 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 32) + ptr = tilePtr + 0x20 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 16) + ptr = tilePtr + 0x10 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 8) + ptr = tilePtr + 0x08 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 4) + ptr = tilePtr + 0x04 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 2) + ptr = tilePtr + 0x02 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + #if (CR_BIN_SQR > 1) + ptr = tilePtr + 0x01 * 4; if (emitInBin >= *(U32*)ptr) tilePtr = ptr; + #endif + + int tileInBin = (tilePtr - tileBase) >> 2; + int emitInTile = emitInBin - *(U32*)tilePtr; + + // Find warp in tile. + + int warpStep = (CR_BIN_SQR + 1) * 4; + U8* warpBase = (U8*)&s_warpEmitPrefixSum[0][tileInBin] - warpStep; + U8* warpPtr = warpBase; + + #if (CR_COARSE_WARPS > 8) + ptr = warpPtr + 0x08 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + #if (CR_COARSE_WARPS > 4) + ptr = warpPtr + 0x04 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + #if (CR_COARSE_WARPS > 2) + ptr = warpPtr + 0x02 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + #if (CR_COARSE_WARPS > 1) + ptr = warpPtr + 0x01 * warpStep; if (emitInTile >= *(U32*)ptr) warpPtr = ptr; + #endif + + int warpInTile = (warpPtr - warpBase) >> (CR_BIN_LOG2 * 2 + 2); + U32 emitMask = *(U32*)(warpPtr + warpStep + ((U8*)s_warpEmitMask - (U8*)s_warpEmitPrefixSum)); + int emitInWarp = emitInTile - *(U32*)(warpPtr + warpStep) + __popc(emitMask); + + // Find thread in warp. + + int threadInWarp = 0; + int pop = __popc(emitMask & 0xFFFF); + bool pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x10; + if (pred) threadInWarp += 0x10; + + pop = __popc(emitMask & 0xFF); + pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x08; + if (pred) threadInWarp += 0x08; + + pop = __popc(emitMask & 0xF); + pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x04; + if (pred) threadInWarp += 0x04; + + pop = __popc(emitMask & 0x3); + pred = (emitInWarp >= pop); + if (pred) emitInWarp -= pop; + if (pred) emitMask >>= 0x02; + if (pred) threadInWarp += 0x02; + + if (emitInWarp >= (emitMask & 1)) + threadInWarp++; + + // Figure out where to write. + + int currOfs = s_tileStreamCurrOfs[tileInBin]; + int spaceLeft = -currOfs & (CR_TILE_SEG_SIZE - 1); + int outOfs = emitInTile; + + if (outOfs < spaceLeft) + outOfs += currOfs; + else + { + int allocLo = firstAllocSeg + s_tileAllocPrefixSum[tileInBin]; + outOfs += (allocLo << CR_TILE_SEG_LOG2) - spaceLeft; + } + + // Write. + + int queueIdx = warpInTile * 32 + threadInWarp; + int triIdx = s_triQueue[(triQueueReadPos + queueIdx) & (CR_COARSE_QUEUE_SIZE - 1)]; + + tileSegData[outOfs] = triIdx; + } + + //------------------------------------------------------------------------ + // Patch. + //------------------------------------------------------------------------ + + // Allocated segment per thread: Initialize next-pointer and count. + + for (int i = CR_COARSE_WARPS * 32 - 1 - thrInBlock; i < totalAllocs; i += CR_COARSE_WARPS * 32) + { + int segIdx = firstAllocSeg + i; + tileSegNext[segIdx] = segIdx + 1; + tileSegCount[segIdx] = CR_TILE_SEG_SIZE; + } + + // Tile per thread: Fix previous segment's next-pointer and update s_tileStreamCurrOfs. + + __syncthreads(); + for (int tileInBin = CR_COARSE_WARPS * 32 - 1 - thrInBlock; tileInBin < CR_BIN_SQR; tileInBin += CR_COARSE_WARPS * 32) + { + int oldOfs = s_tileStreamCurrOfs[tileInBin]; + int newOfs = oldOfs + s_warpEmitPrefixSum[CR_COARSE_WARPS - 1][tileInBin]; + int allocLo = s_tileAllocPrefixSum[tileInBin]; + int allocHi = s_tileAllocPrefixSum[tileInBin + 1]; + + if (allocLo != allocHi) + { + S32* nextPtr = &tileSegNext[(oldOfs - 1) >> CR_TILE_SEG_LOG2]; + if (oldOfs < 0) + nextPtr = &tileFirstSeg[binTileIdx + globalTileIdx(tileInBin, p.widthTiles)]; + *nextPtr = firstAllocSeg + allocLo; + + newOfs--; + newOfs &= CR_TILE_SEG_SIZE - 1; + newOfs |= (firstAllocSeg + allocHi - 1) << CR_TILE_SEG_LOG2; + newOfs++; + } + s_tileStreamCurrOfs[tileInBin] = newOfs; + } + + // Advance queue read pointer. + // Queue became empty => bin done. + + triQueueReadPos += CR_COARSE_WARPS * 32; + } + while (triQueueReadPos < triQueueWritePos); + + // Tile per thread: Fix next-pointer and count of the last segment. + // 32 tiles per warp: Count active tiles. + + __syncthreads(); + + for (int tileInBin_base = 0; tileInBin_base < CR_BIN_SQR; tileInBin_base += CR_COARSE_WARPS * 32) + { + int tileInBin = tileInBin_base + thrInBlock; + bool act = (tileInBin < CR_BIN_SQR); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + int tileX = tileInBin & (CR_BIN_SIZE - 1); + int tileY = tileInBin >> CR_BIN_LOG2; + bool force = (p.deferredClear & tileX <= maxTileXInBin & tileY <= maxTileYInBin); + + int ofs = s_tileStreamCurrOfs[tileInBin]; + int segIdx = (ofs - 1) >> CR_TILE_SEG_LOG2; + int segCount = ofs & (CR_TILE_SEG_SIZE - 1); + + if (ofs >= 0) + tileSegNext[segIdx] = -1; + else if (force) + { + s_tileStreamCurrOfs[tileInBin] = 0; + tileFirstSeg[binTileIdx + tileX + tileY * p.widthTiles] = -1; + } + + if (segCount != 0) + tileSegCount[segIdx] = segCount; + + U32 res = __ballot_sync(actMask, ofs >= 0 | force); + if (threadIdx.x == 0) + s_scanTemp[0][(tileInBin >> 5) + 16] = __popc(res); + } + } + + // First warp: Scan-8. + // One thread: Allocate space for active tiles. + + __syncthreads(); + + bool scan8 = (thrInBlock < CR_BIN_SQR / 32); + U32 scan8Mask = __ballot_sync(~0u, scan8); + if (scan8) + { + volatile U32* v = &s_scanTemp[0][thrInBlock + 16]; + U32 sum = v[0]; + #if (CR_BIN_SQR > 1 * 32) + sum += v[-1]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 2 * 32) + sum += v[-2]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + #if (CR_BIN_SQR > 4 * 32) + sum += v[-4]; __syncwarp(scan8Mask); v[0] = sum; __syncwarp(scan8Mask); + #endif + + if (thrInBlock == CR_BIN_SQR / 32 - 1) + s_firstActiveIdx = atomicAdd(&atomics.numActiveTiles, sum); + } + + // Tile per thread: Output active tiles. + + __syncthreads(); + + for (int tileInBin_base = 0; tileInBin_base < CR_BIN_SQR; tileInBin_base += CR_COARSE_WARPS * 32) + { + int tileInBin = tileInBin_base + thrInBlock; + bool act = (tileInBin < CR_BIN_SQR) && (s_tileStreamCurrOfs[tileInBin] >= 0); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + int activeIdx = s_firstActiveIdx; + activeIdx += s_scanTemp[0][(tileInBin >> 5) + 15]; + activeIdx += __popc(actMask & getLaneMaskLt()); + activeTiles[activeIdx] = binTileIdx + globalTileIdx(tileInBin, p.widthTiles); + } + } + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Constants.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Constants.hpp new file mode 100644 index 0000000000000000000000000000000000000000..42d757fbba82e75c1358caee7419a4be009d7464 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Constants.hpp @@ -0,0 +1,73 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ + +#define CR_MAXVIEWPORT_LOG2 11 // ViewportSize / PixelSize. +#define CR_SUBPIXEL_LOG2 4 // PixelSize / SubpixelSize. + +#define CR_MAXBINS_LOG2 4 // ViewportSize / BinSize. +#define CR_BIN_LOG2 4 // BinSize / TileSize. +#define CR_TILE_LOG2 3 // TileSize / PixelSize. + +#define CR_COVER8X8_LUT_SIZE 768 // 64-bit entries. +#define CR_FLIPBIT_FLIP_Y 2 +#define CR_FLIPBIT_FLIP_X 3 +#define CR_FLIPBIT_SWAP_XY 4 +#define CR_FLIPBIT_COMPL 5 + +#define CR_BIN_STREAMS_LOG2 4 +#define CR_BIN_SEG_LOG2 9 // 32-bit entries. +#define CR_TILE_SEG_LOG2 5 // 32-bit entries. + +#define CR_MAXSUBTRIS_LOG2 24 // Triangle structs. Dictated by CoarseRaster. +#define CR_COARSE_QUEUE_LOG2 10 // Triangles. + +#define CR_SETUP_WARPS 2 +#define CR_SETUP_OPT_BLOCKS 8 +#define CR_BIN_WARPS 16 +#define CR_COARSE_WARPS 16 // Must be a power of two. +#define CR_FINE_MAX_WARPS 20 + +#define CR_EMBED_IMAGE_PARAMS 32 // Number of per-image parameter structs embedded in kernel launch parameter block. + +//------------------------------------------------------------------------ + +#define CR_MAXVIEWPORT_SIZE (1 << CR_MAXVIEWPORT_LOG2) +#define CR_SUBPIXEL_SIZE (1 << CR_SUBPIXEL_LOG2) +#define CR_SUBPIXEL_SQR (1 << (CR_SUBPIXEL_LOG2 * 2)) + +#define CR_MAXBINS_SIZE (1 << CR_MAXBINS_LOG2) +#define CR_MAXBINS_SQR (1 << (CR_MAXBINS_LOG2 * 2)) +#define CR_BIN_SIZE (1 << CR_BIN_LOG2) +#define CR_BIN_SQR (1 << (CR_BIN_LOG2 * 2)) + +#define CR_MAXTILES_LOG2 (CR_MAXBINS_LOG2 + CR_BIN_LOG2) +#define CR_MAXTILES_SIZE (1 << CR_MAXTILES_LOG2) +#define CR_MAXTILES_SQR (1 << (CR_MAXTILES_LOG2 * 2)) +#define CR_TILE_SIZE (1 << CR_TILE_LOG2) +#define CR_TILE_SQR (1 << (CR_TILE_LOG2 * 2)) + +#define CR_BIN_STREAMS_SIZE (1 << CR_BIN_STREAMS_LOG2) +#define CR_BIN_SEG_SIZE (1 << CR_BIN_SEG_LOG2) +#define CR_TILE_SEG_SIZE (1 << CR_TILE_SEG_LOG2) + +#define CR_MAXSUBTRIS_SIZE (1 << CR_MAXSUBTRIS_LOG2) +#define CR_COARSE_QUEUE_SIZE (1 << CR_COARSE_QUEUE_LOG2) + +//------------------------------------------------------------------------ +// When evaluating interpolated Z pixel centers, we may introduce an error +// of (+-CR_LERP_ERROR) ULPs. + +#define CR_LERP_ERROR(SAMPLES_LOG2) (2200u << (SAMPLES_LOG2)) +#define CR_DEPTH_MIN CR_LERP_ERROR(3) +#define CR_DEPTH_MAX (CR_U32_MAX - CR_LERP_ERROR(3)) + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CudaRaster.cpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CudaRaster.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21f9a484324c7e5409b8e10796e7b6d12ffdbe2f --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/CudaRaster.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "Defs.hpp" +#include "../CudaRaster.hpp" +#include "RasterImpl.hpp" + +using namespace CR; + +//------------------------------------------------------------------------ +// Stub interface implementation. +//------------------------------------------------------------------------ + +CudaRaster::CudaRaster() +{ + m_impl = new RasterImpl(); +} + +CudaRaster::~CudaRaster() +{ + delete m_impl; +} + +void CudaRaster::setBufferSize(int width, int height, int numImages) +{ + m_impl->setBufferSize(Vec3i(width, height, numImages)); +} + +void CudaRaster::setViewport(int width, int height, int offsetX, int offsetY) +{ + m_impl->setViewport(Vec2i(width, height), Vec2i(offsetX, offsetY)); +} + +void CudaRaster::setRenderModeFlags(U32 flags) +{ + m_impl->setRenderModeFlags(flags); +} + +void CudaRaster::deferredClear(U32 clearColor) +{ + m_impl->deferredClear(clearColor); +} + +void CudaRaster::setVertexBuffer(void* vertices, int numVertices) +{ + m_impl->setVertexBuffer(vertices, numVertices); +} + +void CudaRaster::setIndexBuffer(void* indices, int numTriangles) +{ + m_impl->setIndexBuffer(indices, numTriangles); +} + +bool CudaRaster::drawTriangles(const int* ranges, bool peel, cudaStream_t stream) +{ + return m_impl->drawTriangles((const Vec2i*)ranges, peel, stream); +} + +void* CudaRaster::getColorBuffer(void) +{ + return m_impl->getColorBuffer(); +} + +void* CudaRaster::getDepthBuffer(void) +{ + return m_impl->getDepthBuffer(); +} + +void CudaRaster::swapDepthAndPeel(void) +{ + m_impl->swapDepthAndPeel(); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Defs.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Defs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5831664434b613c8443f2e1ece0a0ede0505ca80 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Defs.hpp @@ -0,0 +1,90 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include +#include + +namespace CR +{ +//------------------------------------------------------------------------ + +#ifndef NULL +# define NULL 0 +#endif + +#ifdef __CUDACC__ +# define CR_CUDA 1 +#else +# define CR_CUDA 0 +#endif + +#if CR_CUDA +# define CR_CUDA_FUNC __device__ __inline__ +# define CR_CUDA_CONST __constant__ +#else +# define CR_CUDA_FUNC inline +# define CR_CUDA_CONST static const +#endif + +#define CR_UNREF(X) ((void)(X)) +#define CR_ARRAY_SIZE(X) ((int)(sizeof(X) / sizeof((X)[0]))) + +//------------------------------------------------------------------------ + +typedef uint8_t U8; +typedef uint16_t U16; +typedef uint32_t U32; +typedef uint64_t U64; +typedef int8_t S8; +typedef int16_t S16; +typedef int32_t S32; +typedef int64_t S64; +typedef float F32; +typedef double F64; +typedef void (*FuncPtr)(void); + +//------------------------------------------------------------------------ + +#define CR_U32_MAX (0xFFFFFFFFu) +#define CR_S32_MIN (~0x7FFFFFFF) +#define CR_S32_MAX (0x7FFFFFFF) +#define CR_U64_MAX ((U64)(S64)-1) +#define CR_S64_MIN ((S64)-1 << 63) +#define CR_S64_MAX (~((S64)-1 << 63)) +#define CR_F32_MIN (1.175494351e-38f) +#define CR_F32_MAX (3.402823466e+38f) +#define CR_F64_MIN (2.2250738585072014e-308) +#define CR_F64_MAX (1.7976931348623158e+308) + +//------------------------------------------------------------------------ +// Misc types. + +class Vec2i +{ +public: + Vec2i(int x_, int y_) : x(x_), y(y_) {} + int x, y; +}; + +class Vec3i +{ +public: + Vec3i(int x_, int y_, int z_) : x(x_), y(y_), z(z_) {} + int x, y, z; +}; + +//------------------------------------------------------------------------ +// CUDA utilities. + +#if CR_CUDA +# define globalThreadIdx (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * (blockIdx.x + gridDim.x * blockIdx.y))) +#endif + +//------------------------------------------------------------------------ +} // namespace CR diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/FineRaster.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/FineRaster.inl new file mode 100644 index 0000000000000000000000000000000000000000..8704cc8fd5473c67cb9f97bfb42d76ff0539beb8 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/FineRaster.inl @@ -0,0 +1,385 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Utility funcs. +//------------------------------------------------------------------------ + +__device__ __inline__ void initTileZMax(U32& tileZMax, bool& tileZUpd, volatile U32* tileDepth) +{ + tileZMax = CR_DEPTH_MAX; + tileZUpd = (::min(tileDepth[threadIdx.x], tileDepth[threadIdx.x + 32]) < tileZMax); +} + +__device__ __inline__ void updateTileZMax(U32& tileZMax, bool& tileZUpd, volatile U32* tileDepth, volatile U32* temp) +{ + // Entry is warp-coherent. + if (__any_sync(~0u, tileZUpd)) + { + U32 z = ::max(tileDepth[threadIdx.x], tileDepth[threadIdx.x + 32]); __syncwarp(); + temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 1]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 2]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 4]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 8]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + z = ::max(z, temp[threadIdx.x + 16 - 16]); __syncwarp(); temp[threadIdx.x + 16] = z; __syncwarp(); + tileZMax = temp[47]; + tileZUpd = false; + } +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void getTriangle(const CRParams& p, S32& triIdx, S32& dataIdx, uint4& triHeader, S32& segment) +{ + const CRTriangleHeader* triHeaderPtr = (const CRTriangleHeader*)p.triHeader + blockIdx.z * p.maxSubtris;; + const S32* tileSegData = (const S32*)p.tileSegData + p.maxTileSegs * CR_TILE_SEG_SIZE * blockIdx.z; + const S32* tileSegNext = (const S32*)p.tileSegNext + p.maxTileSegs * blockIdx.z; + const S32* tileSegCount = (const S32*)p.tileSegCount + p.maxTileSegs * blockIdx.z; + + if (threadIdx.x >= tileSegCount[segment]) + { + triIdx = -1; + dataIdx = -1; + } + else + { + int subtriIdx = tileSegData[segment * CR_TILE_SEG_SIZE + threadIdx.x]; + triIdx = subtriIdx >> 3; + dataIdx = triIdx; + subtriIdx &= 7; + if (subtriIdx != 7) + dataIdx = triHeaderPtr[triIdx].misc + subtriIdx; + triHeader = *((uint4*)triHeaderPtr + dataIdx); + } + + // advance to next segment + segment = tileSegNext[segment]; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ bool earlyZCull(uint4 triHeader, U32 tileZMax) +{ + U32 zmin = triHeader.w & 0xFFFFF000; + return (zmin > tileZMax); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 trianglePixelCoverage(const CRParams& p, const uint4& triHeader, int tileX, int tileY, volatile U64* s_cover8x8_lut) +{ + int baseX = (tileX << (CR_TILE_LOG2 + CR_SUBPIXEL_LOG2)) - ((p.widthPixelsVp - 1) << (CR_SUBPIXEL_LOG2 - 1)); + int baseY = (tileY << (CR_TILE_LOG2 + CR_SUBPIXEL_LOG2)) - ((p.heightPixelsVp - 1) << (CR_SUBPIXEL_LOG2 - 1)); + + // extract S16 vertex positions while subtracting tile coordinates + S32 v0x = sub_s16lo_s16lo(triHeader.x, baseX); + S32 v0y = sub_s16hi_s16lo(triHeader.x, baseY); + S32 v01x = sub_s16lo_s16lo(triHeader.y, triHeader.x); + S32 v01y = sub_s16hi_s16hi(triHeader.y, triHeader.x); + S32 v20x = sub_s16lo_s16lo(triHeader.x, triHeader.z); + S32 v20y = sub_s16hi_s16hi(triHeader.x, triHeader.z); + + // extract flipbits + U32 f01 = (triHeader.w >> 6) & 0x3C; + U32 f12 = (triHeader.w >> 2) & 0x3C; + U32 f20 = (triHeader.w << 2) & 0x3C; + + // compute per-edge coverage masks + U64 c01, c12, c20; + c01 = cover8x8_exact_fast(v0x, v0y, v01x, v01y, f01, s_cover8x8_lut); + c12 = cover8x8_exact_fast(v0x + v01x, v0y + v01y, -v01x - v20x, -v01y - v20y, f12, s_cover8x8_lut); + c20 = cover8x8_exact_fast(v0x, v0y, v20x, v20y, f20, s_cover8x8_lut); + + // combine masks + return c01 & c12 & c20; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 scan32_value(U32 value, volatile U32* temp) +{ + __syncwarp(); + temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 1]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 2]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 4]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 8]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + value += temp[threadIdx.x + 16 - 16]; __syncwarp(); temp[threadIdx.x + 16] = value; __syncwarp(); + return value; +} + +__device__ __inline__ volatile const U32& scan32_total(volatile U32* temp) +{ + return temp[47]; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ S32 findBit(U64 mask, int idx) +{ + U32 x = getLo(mask); + int pop = __popc(x); + bool p = (pop <= idx); + if (p) x = getHi(mask); + if (p) idx -= pop; + int bit = p ? 32 : 0; + + pop = __popc(x & 0x0000ffffu); + p = (pop <= idx); + if (p) x >>= 16; + if (p) bit += 16; + if (p) idx -= pop; + + U32 tmp = x & 0x000000ffu; + pop = __popc(tmp); + p = (pop <= idx); + if (p) tmp = x & 0x0000ff00u; + if (p) idx -= pop; + + return findLeadingOne(tmp) + bit - idx; +} + +//------------------------------------------------------------------------ +// Single-sample implementation. +//------------------------------------------------------------------------ + +__device__ __inline__ void executeROP(U32 color, U32 depth, volatile U32* pColor, volatile U32* pDepth, U32 ropMask) +{ + atomicMin((U32*)pDepth, depth); + __syncwarp(ropMask); + bool act = (depth == *pDepth); + __syncwarp(ropMask); + U32 actMask = __ballot_sync(ropMask, act); + if (act) + { + *pDepth = 0; + __syncwarp(actMask); + atomicMax((U32*)pDepth, threadIdx.x); + __syncwarp(actMask); + if (*pDepth == threadIdx.x) + { + *pDepth = depth; + *pColor = color; + } + __syncwarp(actMask); + } +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void fineRasterImpl(const CRParams p) +{ + // for 20 warps: + __shared__ volatile U64 s_cover8x8_lut[CR_COVER8X8_LUT_SIZE]; // 6KB + __shared__ volatile U32 s_tileColor [CR_FINE_MAX_WARPS][CR_TILE_SQR]; // 5KB + __shared__ volatile U32 s_tileDepth [CR_FINE_MAX_WARPS][CR_TILE_SQR]; // 5KB + __shared__ volatile U32 s_tilePeel [CR_FINE_MAX_WARPS][CR_TILE_SQR]; // 5KB + __shared__ volatile U32 s_triDataIdx [CR_FINE_MAX_WARPS][64]; // 5KB CRTriangleData index + __shared__ volatile U64 s_triangleCov [CR_FINE_MAX_WARPS][64]; // 10KB coverage mask + __shared__ volatile U32 s_triangleFrag[CR_FINE_MAX_WARPS][64]; // 5KB fragment index + __shared__ volatile U32 s_temp [CR_FINE_MAX_WARPS][80]; // 6.25KB + // = 47.25KB total + + CRAtomics& atomics = p.atomics[blockIdx.z]; + const CRTriangleData* triData = (const CRTriangleData*)p.triData + blockIdx.z * p.maxSubtris; + + const S32* activeTiles = (const S32*)p.activeTiles + CR_MAXTILES_SQR * blockIdx.z; + const S32* tileFirstSeg = (const S32*)p.tileFirstSeg + CR_MAXTILES_SQR * blockIdx.z; + + volatile U32* tileColor = s_tileColor[threadIdx.y]; + volatile U32* tileDepth = s_tileDepth[threadIdx.y]; + volatile U32* tilePeel = s_tilePeel[threadIdx.y]; + volatile U32* triDataIdx = s_triDataIdx[threadIdx.y]; + volatile U64* triangleCov = s_triangleCov[threadIdx.y]; + volatile U32* triangleFrag = s_triangleFrag[threadIdx.y]; + volatile U32* temp = s_temp[threadIdx.y]; + + if (atomics.numSubtris > p.maxSubtris || atomics.numBinSegs > p.maxBinSegs || atomics.numTileSegs > p.maxTileSegs) + return; + + temp[threadIdx.x] = 0; // first 16 elements of temp are always zero + cover8x8_setupLUT(s_cover8x8_lut); + __syncthreads(); + + // loop over tiles + for (;;) + { + // pick a tile + if (threadIdx.x == 0) + temp[16] = atomicAdd(&atomics.fineCounter, 1); + __syncwarp(); + int activeIdx = temp[16]; + if (activeIdx >= atomics.numActiveTiles) + break; + + int tileIdx = activeTiles[activeIdx]; + S32 segment = tileFirstSeg[tileIdx]; + int tileY = tileIdx / p.widthTiles; + int tileX = tileIdx - tileY * p.widthTiles; + int px = (tileX << CR_TILE_LOG2) + (threadIdx.x & (CR_TILE_SIZE - 1)); + int py = (tileY << CR_TILE_LOG2) + (threadIdx.x >> CR_TILE_LOG2); + + // initialize per-tile state + int triRead = 0, triWrite = 0; + int fragRead = 0, fragWrite = 0; + if (threadIdx.x == 0) + triangleFrag[63] = 0; // "previous triangle" + + // deferred clear => clear tile + if (p.deferredClear) + { + tileColor[threadIdx.x] = p.clearColor; + tileDepth[threadIdx.x] = p.clearDepth; + tileColor[threadIdx.x + 32] = p.clearColor; + tileDepth[threadIdx.x + 32] = p.clearDepth; + } + else // otherwise => read tile from framebuffer + { + U32* pColor = (U32*)p.colorBuffer + p.strideX * p.strideY * blockIdx.z; + U32* pDepth = (U32*)p.depthBuffer + p.strideX * p.strideY * blockIdx.z; + tileColor[threadIdx.x] = pColor[px + p.strideX * py]; + tileDepth[threadIdx.x] = pDepth[px + p.strideX * py]; + tileColor[threadIdx.x + 32] = pColor[px + p.strideX * (py + 4)]; + tileDepth[threadIdx.x + 32] = pDepth[px + p.strideX * (py + 4)]; + } + + // read peeling inputs if enabled + if (p.renderModeFlags & CudaRaster::RenderModeFlag_EnableDepthPeeling) + { + U32* pPeel = (U32*)p.peelBuffer + p.strideX * p.strideY * blockIdx.z; + tilePeel[threadIdx.x] = pPeel[px + p.strideX * py]; + tilePeel[threadIdx.x + 32] = pPeel[px + p.strideX * (py + 4)]; + } + + U32 tileZMax; + bool tileZUpd; + initTileZMax(tileZMax, tileZUpd, tileDepth); + + // process fragments + for(;;) + { + // need to queue more fragments? + if (fragWrite - fragRead < 32 && segment >= 0) + { + // update tile z - coherent over warp + updateTileZMax(tileZMax, tileZUpd, tileDepth, temp); + + // read triangles + do + { + // read triangle index and data, advance to next segment + S32 triIdx, dataIdx; + uint4 triHeader; + getTriangle(p, triIdx, dataIdx, triHeader, segment); + + // early z cull + if (triIdx >= 0 && earlyZCull(triHeader, tileZMax)) + triIdx = -1; + + // determine coverage + U64 coverage = trianglePixelCoverage(p, triHeader, tileX, tileY, s_cover8x8_lut); + S32 pop = (triIdx == -1) ? 0 : __popcll(coverage); + + // fragment count scan + U32 frag = scan32_value(pop, temp); + frag += fragWrite; // frag now holds cumulative fragment count + fragWrite += scan32_total(temp); + + // queue non-empty triangles + U32 goodMask = __ballot_sync(~0u, pop != 0); + if (pop != 0) + { + int idx = (triWrite + __popc(goodMask & getLaneMaskLt())) & 63; + triDataIdx [idx] = dataIdx; + triangleFrag[idx] = frag; + triangleCov [idx] = coverage; + } + triWrite += __popc(goodMask); + } + while (fragWrite - fragRead < 32 && segment >= 0); + } + __syncwarp(); + + // end of segment? + if (fragRead == fragWrite) + break; + + // clear triangle boundaries + temp[threadIdx.x + 16] = 0; + __syncwarp(); + + // tag triangle boundaries + if (triRead + threadIdx.x < triWrite) + { + int idx = triangleFrag[(triRead + threadIdx.x) & 63] - fragRead; + if (idx <= 32) + temp[idx + 16 - 1] = 1; + } + __syncwarp(); + + int ropLaneIdx = threadIdx.x; + U32 boundaryMask = __ballot_sync(~0u, temp[ropLaneIdx + 16]); + + // distribute fragments + bool hasFragment = (ropLaneIdx < fragWrite - fragRead); + U32 fragmentMask = __ballot_sync(~0u, hasFragment); + if (hasFragment) + { + int triBufIdx = (triRead + __popc(boundaryMask & getLaneMaskLt())) & 63; + int fragIdx = add_sub(fragRead, ropLaneIdx, triangleFrag[(triBufIdx - 1) & 63]); + U64 coverage = triangleCov[triBufIdx]; + int pixelInTile = findBit(coverage, fragIdx); + int dataIdx = triDataIdx[triBufIdx]; + + // determine pixel position + U32 pixelX = (tileX << CR_TILE_LOG2) + (pixelInTile & 7); + U32 pixelY = (tileY << CR_TILE_LOG2) + (pixelInTile >> 3); + + // depth test + U32 depth = 0; + uint4 td = *((uint4*)triData + dataIdx * (sizeof(CRTriangleData) >> 4)); + + depth = td.x * pixelX + td.y * pixelY + td.z; + bool zkill = (p.renderModeFlags & CudaRaster::RenderModeFlag_EnableDepthPeeling) && (depth <= tilePeel[pixelInTile]); + if (!zkill) + { + U32 oldDepth = tileDepth[pixelInTile]; + if (depth > oldDepth) + zkill = true; + else if (oldDepth == tileZMax) + tileZUpd = true; // we are replacing previous zmax => need to update + } + + U32 ropMask = __ballot_sync(fragmentMask, !zkill); + if (!zkill) + executeROP(td.w, depth, &tileColor[pixelInTile], &tileDepth[pixelInTile], ropMask); + } + // no need to sync, as next up is updateTileZMax that does internal warp sync + + // update counters + fragRead = ::min(fragRead + 32, fragWrite); + triRead += __popc(boundaryMask); + } + + // Write tile back to the framebuffer. + if (true) + { + int px = (tileX << CR_TILE_LOG2) + (threadIdx.x & (CR_TILE_SIZE - 1)); + int py = (tileY << CR_TILE_LOG2) + (threadIdx.x >> CR_TILE_LOG2); + U32* pColor = (U32*)p.colorBuffer + p.strideX * p.strideY * blockIdx.z; + U32* pDepth = (U32*)p.depthBuffer + p.strideX * p.strideY * blockIdx.z; + pColor[px + p.strideX * py] = tileColor[threadIdx.x]; + pDepth[px + p.strideX * py] = tileDepth[threadIdx.x]; + pColor[px + p.strideX * (py + 4)] = tileColor[threadIdx.x + 32]; + pDepth[px + p.strideX * (py + 4)] = tileDepth[threadIdx.x + 32]; + } + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/PrivateDefs.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/PrivateDefs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fda59339b310f51253c29cac42d47086ef35c28a --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/PrivateDefs.hpp @@ -0,0 +1,153 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "Defs.hpp" +#include "Constants.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ +// Projected triangle. +//------------------------------------------------------------------------ + +struct CRTriangleHeader +{ + S16 v0x; // Subpixels relative to viewport center. Valid if triSubtris = 1. + S16 v0y; + S16 v1x; + S16 v1y; + S16 v2x; + S16 v2y; + + U32 misc; // triSubtris=1: (zmin:20, f01:4, f12:4, f20:4), triSubtris>=2: (subtriBase) +}; + +//------------------------------------------------------------------------ + +struct CRTriangleData +{ + U32 zx; // zx * sampleX + zy * sampleY + zb = lerp(CR_DEPTH_MIN, CR_DEPTH_MAX, (clipZ / clipW + 1) / 2) + U32 zy; + U32 zb; + U32 id; // Triangle id. +}; + +//------------------------------------------------------------------------ +// Device-side structures. +//------------------------------------------------------------------------ + +struct CRAtomics +{ + // Setup. + S32 numSubtris; // = numTris + + // Bin. + S32 binCounter; // = 0 + S32 numBinSegs; // = 0 + + // Coarse. + S32 coarseCounter; // = 0 + S32 numTileSegs; // = 0 + S32 numActiveTiles; // = 0 + + // Fine. + S32 fineCounter; // = 0 +}; + +//------------------------------------------------------------------------ + +struct CRImageParams +{ + S32 triOffset; // First triangle index to draw. + S32 triCount; // Number of triangles to draw. + S32 binBatchSize; // Number of triangles per batch. +}; + +//------------------------------------------------------------------------ + +struct CRParams +{ + // Common. + + CRAtomics* atomics; // Work counters. Per-image. + S32 numImages; // Batch size. + S32 totalCount; // In range mode, total number of triangles to render. + S32 instanceMode; // 0 = range mode, 1 = instance mode. + + S32 numVertices; // Number of vertices in input buffer, not counting multiples in instance mode. + S32 numTriangles; // Number of triangles in input buffer. + void* vertexBuffer; // numVertices * float4(x, y, z, w) + void* indexBuffer; // numTriangles * int3(vi0, vi1, vi2) + + S32 widthPixels; // Render buffer size in pixels. Must be multiple of tile size (8x8). + S32 heightPixels; + S32 widthPixelsVp; // Viewport size in pixels. + S32 heightPixelsVp; + S32 widthBins; // widthPixels / CR_BIN_SIZE + S32 heightBins; // heightPixels / CR_BIN_SIZE + S32 numBins; // widthBins * heightBins + + F32 xs; // Vertex position adjustments for tiled rendering. + F32 ys; + F32 xo; + F32 yo; + + S32 widthTiles; // widthPixels / CR_TILE_SIZE + S32 heightTiles; // heightPixels / CR_TILE_SIZE + S32 numTiles; // widthTiles * heightTiles + + U32 renderModeFlags; + S32 deferredClear; // 1 = Clear framebuffer before rendering triangles. + U32 clearColor; + U32 clearDepth; + + // These are uniform across batch. + + S32 maxSubtris; + S32 maxBinSegs; + S32 maxTileSegs; + + // Setup output / bin input. + + void* triSubtris; // maxSubtris * U8 + void* triHeader; // maxSubtris * CRTriangleHeader + void* triData; // maxSubtris * CRTriangleData + + // Bin output / coarse input. + + void* binSegData; // maxBinSegs * CR_BIN_SEG_SIZE * S32 + void* binSegNext; // maxBinSegs * S32 + void* binSegCount; // maxBinSegs * S32 + void* binFirstSeg; // CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * (S32 segIdx), -1 = none + void* binTotal; // CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * (S32 numTris) + + // Coarse output / fine input. + + void* tileSegData; // maxTileSegs * CR_TILE_SEG_SIZE * S32 + void* tileSegNext; // maxTileSegs * S32 + void* tileSegCount; // maxTileSegs * S32 + void* activeTiles; // CR_MAXTILES_SQR * (S32 tileIdx) + void* tileFirstSeg; // CR_MAXTILES_SQR * (S32 segIdx), -1 = none + + // Surface buffers. Outer tile offset is baked into pointers. + + void* colorBuffer; // sizePixels.x * sizePixels.y * numImages * U32 + void* depthBuffer; // sizePixels.x * sizePixels.y * numImages * U32 + void* peelBuffer; // sizePixels.x * sizePixels.y * numImages * U32, only if peeling enabled. + S32 strideX; // horizontal size in pixels + S32 strideY; // vertical stride in pixels + + // Per-image parameters for first images are embedded here to avoid extra memcpy for small batches. + + CRImageParams imageParamsFirst[CR_EMBED_IMAGE_PARAMS]; + const CRImageParams* imageParamsExtra; // After CR_EMBED_IMAGE_PARAMS. +}; + +//------------------------------------------------------------------------ +} diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.cpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48643b1e59ff00b4b2c44b3e7aa280bc4f486925 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.cpp @@ -0,0 +1,370 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "../../framework.h" +#include "PrivateDefs.hpp" +#include "Constants.hpp" +#include "RasterImpl.hpp" +#include + +using namespace CR; +using std::min; +using std::max; + +//------------------------------------------------------------------------ +// Kernel prototypes and variables. + +void triangleSetupKernel (const CRParams p); +void binRasterKernel (const CRParams p); +void coarseRasterKernel (const CRParams p); +void fineRasterKernel (const CRParams p); + +//------------------------------------------------------------------------ + +RasterImpl::RasterImpl(void) +: m_renderModeFlags (0), + m_deferredClear (false), + m_clearColor (0), + m_vertexPtr (NULL), + m_indexPtr (NULL), + m_numVertices (0), + m_numTriangles (0), + m_bufferSizesReported (0), + + m_numImages (0), + m_bufferSizePixels (0, 0), + m_bufferSizeVp (0, 0), + m_sizePixels (0, 0), + m_sizeVp (0, 0), + m_offsetPixels (0, 0), + m_sizeBins (0, 0), + m_numBins (0), + m_sizeTiles (0, 0), + m_numTiles (0), + + m_numSMs (1), + m_numCoarseBlocksPerSM (1), + m_numFineBlocksPerSM (1), + m_numFineWarpsPerBlock (1), + + m_maxSubtris (1), + m_maxBinSegs (1), + m_maxTileSegs (1) +{ + // Query relevant device attributes. + + int currentDevice = 0; + NVDR_CHECK_CUDA_ERROR(cudaGetDevice(¤tDevice)); + NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&m_numSMs, cudaDevAttrMultiProcessorCount, currentDevice)); + cudaFuncAttributes attr; + NVDR_CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (void*)fineRasterKernel)); + m_numFineWarpsPerBlock = min(attr.maxThreadsPerBlock / 32, CR_FINE_MAX_WARPS); + NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&m_numCoarseBlocksPerSM, (void*)coarseRasterKernel, 32 * CR_COARSE_WARPS, 0)); + NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&m_numFineBlocksPerSM, (void*)fineRasterKernel, 32 * m_numFineWarpsPerBlock, 0)); + + // Setup functions. + + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)triangleSetupKernel, cudaFuncCachePreferShared)); + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)binRasterKernel, cudaFuncCachePreferShared)); + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)coarseRasterKernel, cudaFuncCachePreferShared)); + NVDR_CHECK_CUDA_ERROR(cudaFuncSetCacheConfig((void*)fineRasterKernel, cudaFuncCachePreferShared)); +} + +//------------------------------------------------------------------------ + +RasterImpl::~RasterImpl(void) +{ + // Empty. +} + +//------------------------------------------------------------------------ + +void RasterImpl::setBufferSize(Vec3i size) +{ + // Internal buffer width and height must be divisible by tile size. + int w = (size.x + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int h = (size.y + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + + m_bufferSizePixels = Vec2i(w, h); + m_bufferSizeVp = Vec2i(size.x, size.y); + m_numImages = size.z; + + m_colorBuffer.reset(w * h * size.z * sizeof(U32)); + m_depthBuffer.reset(w * h * size.z * sizeof(U32)); +} + +//------------------------------------------------------------------------ + +void RasterImpl::setViewport(Vec2i size, Vec2i offset) +{ + // Offset must be divisible by tile size. + NVDR_CHECK((offset.x & (CR_TILE_SIZE - 1)) == 0 && (offset.y & (CR_TILE_SIZE - 1)) == 0, "invalid viewport offset"); + + // Round internal viewport size to multiples of tile size. + int w = (size.x + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int h = (size.y + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + + m_sizePixels = Vec2i(w, h); + m_offsetPixels = offset; + m_sizeVp = Vec2i(size.x, size.y); + m_sizeTiles.x = m_sizePixels.x >> CR_TILE_LOG2; + m_sizeTiles.y = m_sizePixels.y >> CR_TILE_LOG2; + m_numTiles = m_sizeTiles.x * m_sizeTiles.y; + m_sizeBins.x = (m_sizeTiles.x + CR_BIN_SIZE - 1) >> CR_BIN_LOG2; + m_sizeBins.y = (m_sizeTiles.y + CR_BIN_SIZE - 1) >> CR_BIN_LOG2; + m_numBins = m_sizeBins.x * m_sizeBins.y; +} + +void RasterImpl::swapDepthAndPeel(void) +{ + m_peelBuffer.reset(m_depthBuffer.getSize()); // Ensure equal size and valid pointer. + + void* tmp = m_depthBuffer.getPtr(); + m_depthBuffer.setPtr(m_peelBuffer.getPtr()); + m_peelBuffer.setPtr(tmp); +} + +//------------------------------------------------------------------------ + +bool RasterImpl::drawTriangles(const Vec2i* ranges, bool peel, cudaStream_t stream) +{ + bool instanceMode = (!ranges); + + int maxSubtrisSlack = 4096; // x 81B = 324KB + int maxBinSegsSlack = 256; // x 2137B = 534KB + int maxTileSegsSlack = 4096; // x 136B = 544KB + + // Resize atomics as needed. + m_crAtomics .grow(m_numImages * sizeof(CRAtomics)); + m_crAtomicsHost.grow(m_numImages * sizeof(CRAtomics)); + + // Size of these buffers doesn't depend on input. + m_binFirstSeg .grow(m_numImages * CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * sizeof(S32)); + m_binTotal .grow(m_numImages * CR_MAXBINS_SQR * CR_BIN_STREAMS_SIZE * sizeof(S32)); + m_activeTiles .grow(m_numImages * CR_MAXTILES_SQR * sizeof(S32)); + m_tileFirstSeg .grow(m_numImages * CR_MAXTILES_SQR * sizeof(S32)); + + // Construct per-image parameters and determine worst-case buffer sizes. + m_crImageParamsHost.grow(m_numImages * sizeof(CRImageParams)); + CRImageParams* imageParams = (CRImageParams*)m_crImageParamsHost.getPtr(); + for (int i=0; i < m_numImages; i++) + { + CRImageParams& ip = imageParams[i]; + + int roundSize = CR_BIN_WARPS * 32; + int minBatches = CR_BIN_STREAMS_SIZE * 2; + int maxRounds = 32; + + ip.triOffset = instanceMode ? 0 : ranges[i].x; + ip.triCount = instanceMode ? m_numTriangles : ranges[i].y; + ip.binBatchSize = min(max(ip.triCount / (roundSize * minBatches), 1), maxRounds) * roundSize; + + m_maxSubtris = max(m_maxSubtris, min(ip.triCount + maxSubtrisSlack, CR_MAXSUBTRIS_SIZE)); + m_maxBinSegs = max(m_maxBinSegs, max(m_numBins * CR_BIN_STREAMS_SIZE, (ip.triCount - 1) / CR_BIN_SEG_SIZE + 1) + maxBinSegsSlack); + m_maxTileSegs = max(m_maxTileSegs, max(m_numTiles, (ip.triCount - 1) / CR_TILE_SEG_SIZE + 1) + maxTileSegsSlack); + } + + // Retry until successful. + + for (;;) + { + // Allocate buffers. + m_triSubtris.reset(m_numImages * m_maxSubtris * sizeof(U8)); + m_triHeader .reset(m_numImages * m_maxSubtris * sizeof(CRTriangleHeader)); + m_triData .reset(m_numImages * m_maxSubtris * sizeof(CRTriangleData)); + + m_binSegData .reset(m_numImages * m_maxBinSegs * CR_BIN_SEG_SIZE * sizeof(S32)); + m_binSegNext .reset(m_numImages * m_maxBinSegs * sizeof(S32)); + m_binSegCount.reset(m_numImages * m_maxBinSegs * sizeof(S32)); + + m_tileSegData .reset(m_numImages * m_maxTileSegs * CR_TILE_SEG_SIZE * sizeof(S32)); + m_tileSegNext .reset(m_numImages * m_maxTileSegs * sizeof(S32)); + m_tileSegCount.reset(m_numImages * m_maxTileSegs * sizeof(S32)); + + // Report if buffers grow from last time. + size_t sizesTotal = getTotalBufferSizes(); + if (sizesTotal > m_bufferSizesReported) + { + size_t sizesMB = ((sizesTotal - 1) >> 20) + 1; // Round up. + sizesMB = ((sizesMB + 9) / 10) * 10; // 10MB granularity enough in this day and age. + LOG(INFO) << "Internal buffers grown to " << sizesMB << " MB"; + m_bufferSizesReported = sizesMB << 20; + } + + // Launch stages. Blocks until everything is done. + launchStages(instanceMode, peel, stream); + + // Peeling iteration cannot fail, so no point checking things further. + if (peel) + break; + + // Atomics after coarse stage are now available. + CRAtomics* atomics = (CRAtomics*)m_crAtomicsHost.getPtr(); + + // Success? + bool failed = false; + for (int i=0; i < m_numImages; i++) + { + const CRAtomics& a = atomics[i]; + failed = failed || (a.numSubtris > m_maxSubtris) || (a.numBinSegs > m_maxBinSegs) || (a.numTileSegs > m_maxTileSegs); + } + if (!failed) + break; // Success! + + // If we were already at maximum capacity, no can do. + if (m_maxSubtris == CR_MAXSUBTRIS_SIZE) + return false; + + // Enlarge buffers and try again. + for (int i=0; i < m_numImages; i++) + { + const CRAtomics& a = atomics[i]; + m_maxSubtris = max(m_maxSubtris, min(a.numSubtris + maxSubtrisSlack, CR_MAXSUBTRIS_SIZE)); + m_maxBinSegs = max(m_maxBinSegs, a.numBinSegs + maxBinSegsSlack); + m_maxTileSegs = max(m_maxTileSegs, a.numTileSegs + maxTileSegsSlack); + } + } + + m_deferredClear = false; + return true; // Success. +} + +//------------------------------------------------------------------------ + +size_t RasterImpl::getTotalBufferSizes(void) const +{ + return + m_colorBuffer.getSize() + m_depthBuffer.getSize() + // Don't include atomics and image params. + m_triSubtris.getSize() + m_triHeader.getSize() + m_triData.getSize() + + m_binFirstSeg.getSize() + m_binTotal.getSize() + m_binSegData.getSize() + m_binSegNext.getSize() + m_binSegCount.getSize() + + m_activeTiles.getSize() + m_tileFirstSeg.getSize() + m_tileSegData.getSize() + m_tileSegNext.getSize() + m_tileSegCount.getSize(); +} + +//------------------------------------------------------------------------ + +void RasterImpl::launchStages(bool instanceMode, bool peel, cudaStream_t stream) +{ + CRImageParams* imageParams = (CRImageParams*)m_crImageParamsHost.getPtr(); + + // Unless peeling, initialize atomics to mostly zero. + CRAtomics* atomics = (CRAtomics*)m_crAtomicsHost.getPtr(); + if (!peel) + { + memset(atomics, 0, m_numImages * sizeof(CRAtomics)); + for (int i=0; i < m_numImages; i++) + atomics[i].numSubtris = imageParams[i].triCount; + } + + // Copy to device. If peeling, this is the state after coarse raster launch on first iteration. + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(m_crAtomics.getPtr(), atomics, m_numImages * sizeof(CRAtomics), cudaMemcpyHostToDevice, stream)); + + // Copy per-image parameters if there are more than fits in launch parameter block and we haven't done it already. + if (!peel && m_numImages > CR_EMBED_IMAGE_PARAMS) + { + int numImageParamsExtra = m_numImages - CR_EMBED_IMAGE_PARAMS; + m_crImageParamsExtra.grow(numImageParamsExtra * sizeof(CRImageParams)); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(m_crImageParamsExtra.getPtr(), imageParams + CR_EMBED_IMAGE_PARAMS, numImageParamsExtra * sizeof(CRImageParams), cudaMemcpyHostToDevice, stream)); + } + + // Set global parameters. + CRParams p; + { + p.atomics = (CRAtomics*)m_crAtomics.getPtr(); + p.numImages = m_numImages; + p.totalCount = 0; // Only relevant in range mode. + p.instanceMode = instanceMode ? 1 : 0; + + p.numVertices = m_numVertices; + p.numTriangles = m_numTriangles; + p.vertexBuffer = m_vertexPtr; + p.indexBuffer = m_indexPtr; + + p.widthPixels = m_sizePixels.x; + p.heightPixels = m_sizePixels.y; + p.widthPixelsVp = m_sizeVp.x; + p.heightPixelsVp = m_sizeVp.y; + p.widthBins = m_sizeBins.x; + p.heightBins = m_sizeBins.y; + p.numBins = m_numBins; + + p.xs = (float)m_bufferSizeVp.x / (float)m_sizeVp.x; + p.ys = (float)m_bufferSizeVp.y / (float)m_sizeVp.y; + p.xo = (float)(m_bufferSizeVp.x - m_sizeVp.x - 2 * m_offsetPixels.x) / (float)m_sizeVp.x; + p.yo = (float)(m_bufferSizeVp.y - m_sizeVp.y - 2 * m_offsetPixels.y) / (float)m_sizeVp.y; + + p.widthTiles = m_sizeTiles.x; + p.heightTiles = m_sizeTiles.y; + p.numTiles = m_numTiles; + + p.renderModeFlags = m_renderModeFlags; + p.deferredClear = m_deferredClear ? 1 : 0; + p.clearColor = m_clearColor; + p.clearDepth = CR_DEPTH_MAX; + + p.maxSubtris = m_maxSubtris; + p.maxBinSegs = m_maxBinSegs; + p.maxTileSegs = m_maxTileSegs; + + p.triSubtris = m_triSubtris.getPtr(); + p.triHeader = m_triHeader.getPtr(); + p.triData = m_triData.getPtr(); + p.binSegData = m_binSegData.getPtr(); + p.binSegNext = m_binSegNext.getPtr(); + p.binSegCount = m_binSegCount.getPtr(); + p.binFirstSeg = m_binFirstSeg.getPtr(); + p.binTotal = m_binTotal.getPtr(); + p.tileSegData = m_tileSegData.getPtr(); + p.tileSegNext = m_tileSegNext.getPtr(); + p.tileSegCount = m_tileSegCount.getPtr(); + p.activeTiles = m_activeTiles.getPtr(); + p.tileFirstSeg = m_tileFirstSeg.getPtr(); + + size_t byteOffset = ((size_t)m_offsetPixels.x + (size_t)m_offsetPixels.y * (size_t)p.strideX) * sizeof(U32); + p.colorBuffer = m_colorBuffer.getPtr(byteOffset); + p.depthBuffer = m_depthBuffer.getPtr(byteOffset); + p.peelBuffer = (m_renderModeFlags & CudaRaster::RenderModeFlag_EnableDepthPeeling) ? m_peelBuffer.getPtr(byteOffset) : 0; + p.strideX = m_bufferSizePixels.x; + p.strideY = m_bufferSizePixels.y; + + memcpy(&p.imageParamsFirst, imageParams, min(m_numImages, CR_EMBED_IMAGE_PARAMS) * sizeof(CRImageParams)); + p.imageParamsExtra = (CRImageParams*)m_crImageParamsExtra.getPtr(); + } + + // Setup block sizes. + + dim3 brBlock(32, CR_BIN_WARPS); + dim3 crBlock(32, CR_COARSE_WARPS); + dim3 frBlock(32, m_numFineWarpsPerBlock); + void* args[] = {&p}; + + // Launch stages from setup to coarse and copy atomics to host only if this is not a single-tile peeling iteration. + if (!peel) + { + if (instanceMode) + { + int setupBlocks = (m_numTriangles - 1) / (32 * CR_SETUP_WARPS) + 1; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)triangleSetupKernel, dim3(setupBlocks, 1, m_numImages), dim3(32, CR_SETUP_WARPS), args, 0, stream)); + } + else + { + for (int i=0; i < m_numImages; i++) + p.totalCount += imageParams[i].triCount; + int setupBlocks = (p.totalCount - 1) / (32 * CR_SETUP_WARPS) + 1; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)triangleSetupKernel, dim3(setupBlocks, 1, 1), dim3(32, CR_SETUP_WARPS), args, 0, stream)); + } + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)binRasterKernel, dim3(CR_BIN_STREAMS_SIZE, 1, m_numImages), brBlock, args, 0, stream)); + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)coarseRasterKernel, dim3(m_numSMs * m_numCoarseBlocksPerSM, 1, m_numImages), crBlock, args, 0, stream)); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(m_crAtomicsHost.getPtr(), m_crAtomics.getPtr(), sizeof(CRAtomics) * m_numImages, cudaMemcpyDeviceToHost, stream)); + } + + // Fine rasterizer is launched always. + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)fineRasterKernel, dim3(m_numSMs * m_numFineBlocksPerSM, 1, m_numImages), frBlock, args, 0, stream)); + NVDR_CHECK_CUDA_ERROR(cudaStreamSynchronize(stream)); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..321c7b30787ef50ac0d066f07ad5113c7079f499 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl.hpp @@ -0,0 +1,102 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "PrivateDefs.hpp" +#include "Buffer.hpp" +#include "../CudaRaster.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ + +class RasterImpl +{ +public: + RasterImpl (void); + ~RasterImpl (void); + + void setBufferSize (Vec3i size); + void setViewport (Vec2i size, Vec2i offset); + void setRenderModeFlags (U32 flags) { m_renderModeFlags = flags; } + void deferredClear (U32 color) { m_deferredClear = true; m_clearColor = color; } + void setVertexBuffer (void* ptr, int numVertices) { m_vertexPtr = ptr; m_numVertices = numVertices; } // GPU pointer. + void setIndexBuffer (void* ptr, int numTriangles) { m_indexPtr = ptr; m_numTriangles = numTriangles; } // GPU pointer. + bool drawTriangles (const Vec2i* ranges, bool peel, cudaStream_t stream); + void* getColorBuffer (void) { return m_colorBuffer.getPtr(); } // GPU pointer. + void* getDepthBuffer (void) { return m_depthBuffer.getPtr(); } // GPU pointer. + void swapDepthAndPeel (void); + size_t getTotalBufferSizes (void) const; + +private: + void launchStages (bool instanceMode, bool peel, cudaStream_t stream); + + // State. + + unsigned int m_renderModeFlags; + bool m_deferredClear; + unsigned int m_clearColor; + void* m_vertexPtr; + void* m_indexPtr; + int m_numVertices; // Input buffer size. + int m_numTriangles; // Input buffer size. + size_t m_bufferSizesReported; // Previously reported buffer sizes. + + // Surfaces. + + Buffer m_colorBuffer; + Buffer m_depthBuffer; + Buffer m_peelBuffer; + int m_numImages; + Vec2i m_bufferSizePixels; // Internal buffer size. + Vec2i m_bufferSizeVp; // Total viewport size. + Vec2i m_sizePixels; // Internal size at which all computation is done, buffers reserved, etc. + Vec2i m_sizeVp; // Size to which output will be cropped outside, determines viewport size. + Vec2i m_offsetPixels; // Viewport offset for tiled rendering. + Vec2i m_sizeBins; + S32 m_numBins; + Vec2i m_sizeTiles; + S32 m_numTiles; + + // Launch sizes etc. + + S32 m_numSMs; + S32 m_numCoarseBlocksPerSM; + S32 m_numFineBlocksPerSM; + S32 m_numFineWarpsPerBlock; + + // Global intermediate buffers. Individual images have offsets to these. + + Buffer m_crAtomics; + HostBuffer m_crAtomicsHost; + HostBuffer m_crImageParamsHost; + Buffer m_crImageParamsExtra; + Buffer m_triSubtris; + Buffer m_triHeader; + Buffer m_triData; + Buffer m_binFirstSeg; + Buffer m_binTotal; + Buffer m_binSegData; + Buffer m_binSegNext; + Buffer m_binSegCount; + Buffer m_activeTiles; + Buffer m_tileFirstSeg; + Buffer m_tileSegData; + Buffer m_tileSegNext; + Buffer m_tileSegCount; + + // Actual buffer sizes. + + S32 m_maxSubtris; + S32 m_maxBinSegs; + S32 m_maxTileSegs; +}; + +//------------------------------------------------------------------------ +} // namespace CR + diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl_.cu b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl_.cu new file mode 100644 index 0000000000000000000000000000000000000000..69c137fdbfaf98f958ed3fe7afebc0d2debf05f1 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/RasterImpl_.cu @@ -0,0 +1,37 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "../CudaRaster.hpp" +#include "PrivateDefs.hpp" +#include "Constants.hpp" +#include "Util.inl" + +namespace CR +{ + +//------------------------------------------------------------------------ +// Stage implementations. +//------------------------------------------------------------------------ + +#include "TriangleSetup.inl" +#include "BinRaster.inl" +#include "CoarseRaster.inl" +#include "FineRaster.inl" + +} + +//------------------------------------------------------------------------ +// Stage entry points. +//------------------------------------------------------------------------ + +__global__ void __launch_bounds__(CR_SETUP_WARPS * 32, CR_SETUP_OPT_BLOCKS) triangleSetupKernel (const CR::CRParams p) { CR::triangleSetupImpl(p); } +__global__ void __launch_bounds__(CR_BIN_WARPS * 32, 1) binRasterKernel (const CR::CRParams p) { CR::binRasterImpl(p); } +__global__ void __launch_bounds__(CR_COARSE_WARPS * 32, 1) coarseRasterKernel (const CR::CRParams p) { CR::coarseRasterImpl(p); } +__global__ void __launch_bounds__(CR_FINE_MAX_WARPS * 32, 1) fineRasterKernel (const CR::CRParams p) { CR::fineRasterImpl(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/TriangleSetup.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/TriangleSetup.inl new file mode 100644 index 0000000000000000000000000000000000000000..716cb0e7c3d9b4f3a8d639f8ad0e23d425fefd91 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/TriangleSetup.inl @@ -0,0 +1,402 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ + +__device__ __inline__ void snapTriangle( + const CRParams& p, + float4 v0, float4 v1, float4 v2, + int2& p0, int2& p1, int2& p2, float3& rcpW, int2& lo, int2& hi) +{ + F32 viewScaleX = (F32)(p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + F32 viewScaleY = (F32)(p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + rcpW = make_float3(1.0f / v0.w, 1.0f / v1.w, 1.0f / v2.w); + p0 = make_int2(f32_to_s32_sat(v0.x * rcpW.x * viewScaleX), f32_to_s32_sat(v0.y * rcpW.x * viewScaleY)); + p1 = make_int2(f32_to_s32_sat(v1.x * rcpW.y * viewScaleX), f32_to_s32_sat(v1.y * rcpW.y * viewScaleY)); + p2 = make_int2(f32_to_s32_sat(v2.x * rcpW.z * viewScaleX), f32_to_s32_sat(v2.y * rcpW.z * viewScaleY)); + lo = make_int2(min_min(p0.x, p1.x, p2.x), min_min(p0.y, p1.y, p2.y)); + hi = make_int2(max_max(p0.x, p1.x, p2.x), max_max(p0.y, p1.y, p2.y)); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 cover8x8_selectFlips(S32 dx, S32 dy) // 10 instr +{ + U32 flips = 0; + if (dy > 0 || (dy == 0 && dx <= 0)) + flips ^= (1 << CR_FLIPBIT_FLIP_X) ^ (1 << CR_FLIPBIT_FLIP_Y) ^ (1 << CR_FLIPBIT_COMPL); + if (dx > 0) + flips ^= (1 << CR_FLIPBIT_FLIP_X) ^ (1 << CR_FLIPBIT_FLIP_Y); + if (::abs(dx) < ::abs(dy)) + flips ^= (1 << CR_FLIPBIT_SWAP_XY) ^ (1 << CR_FLIPBIT_FLIP_Y); + return flips; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ bool prepareTriangle( + const CRParams& p, + int2 p0, int2 p1, int2 p2, int2 lo, int2 hi, + int2& d1, int2& d2, S32& area) +{ + // Backfacing or degenerate => cull. + + d1 = make_int2(p1.x - p0.x, p1.y - p0.y); + d2 = make_int2(p2.x - p0.x, p2.y - p0.y); + area = d1.x * d2.y - d1.y * d2.x; + + if (area == 0) + return false; // Degenerate. + + if (area < 0 && (p.renderModeFlags & CudaRaster::RenderModeFlag_EnableBackfaceCulling) != 0) + return false; // Backfacing. + + // AABB falls between samples => cull. + + int sampleSize = 1 << CR_SUBPIXEL_LOG2; + int biasX = (p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)) - (sampleSize >> 1); + int biasY = (p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)) - (sampleSize >> 1); + int lox = (int)add_add(lo.x, sampleSize - 1, biasX) & -sampleSize; + int loy = (int)add_add(lo.y, sampleSize - 1, biasY) & -sampleSize; + int hix = (hi.x + biasX) & -sampleSize; + int hiy = (hi.y + biasY) & -sampleSize; + + if (lox > hix || loy > hiy) + return false; // Between pixels. + + // AABB covers 1 or 2 samples => cull if they are not covered. + + int diff = add_sub(hix, hiy, lox) - loy; + if (diff <= sampleSize) + { + int2 t0 = make_int2(add_sub(p0.x, biasX, lox), add_sub(p0.y, biasY, loy)); + int2 t1 = make_int2(add_sub(p1.x, biasX, lox), add_sub(p1.y, biasY, loy)); + int2 t2 = make_int2(add_sub(p2.x, biasX, lox), add_sub(p2.y, biasY, loy)); + S32 e0 = t0.x * t1.y - t0.y * t1.x; + S32 e1 = t1.x * t2.y - t1.y * t2.x; + S32 e2 = t2.x * t0.y - t2.y * t0.x; + if (area < 0) + { + e0 = -e0; + e1 = -e1; + e2 = -e2; + } + + if (e0 < 0 || e1 < 0 || e2 < 0) + { + if (diff == 0) + return false; // Between pixels. + + t0 = make_int2(add_sub(p0.x, biasX, hix), add_sub(p0.y, biasY, hiy)); + t1 = make_int2(add_sub(p1.x, biasX, hix), add_sub(p1.y, biasY, hiy)); + t2 = make_int2(add_sub(p2.x, biasX, hix), add_sub(p2.y, biasY, hiy)); + e0 = t0.x * t1.y - t0.y * t1.x; + e1 = t1.x * t2.y - t1.y * t2.x; + e2 = t2.x * t0.y - t2.y * t0.x; + if (area < 0) + { + e0 = -e0; + e1 = -e1; + e2 = -e2; + } + + if (e0 < 0 || e1 < 0 || e2 < 0) + return false; // Between pixels. + } + } + + // Otherwise => proceed to output the triangle. + + return true; // Visible. +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void setupTriangle( + const CRParams& p, + CRTriangleHeader* th, CRTriangleData* td, int triId, + float v0z, float v1z, float v2z, + int2 p0, int2 p1, int2 p2, float3 rcpW, + int2 d1, int2 d2, S32 area) +{ + // Swap vertices 1 and 2 if area is negative. Only executed if backface culling is + // disabled (if it is enabled, we never come here with area < 0). + + if (area < 0) + { + swap(d1, d2); + swap(p1, p2); + swap(v1z, v2z); + swap(rcpW.y, rcpW.z); + area = -area; + } + + int2 wv0; + wv0.x = p0.x + (p.widthPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + wv0.y = p0.y + (p.heightPixelsVp << (CR_SUBPIXEL_LOG2 - 1)); + + // Setup depth plane equation. + + F32 zcoef = (F32)(CR_DEPTH_MAX - CR_DEPTH_MIN) * 0.5f; + F32 zbias = (F32)(CR_DEPTH_MAX + CR_DEPTH_MIN) * 0.5f; + float3 zvert = make_float3( + (v0z * zcoef) * rcpW.x + zbias, + (v1z * zcoef) * rcpW.y + zbias, + (v2z * zcoef) * rcpW.z + zbias + ); + int2 zv0 = make_int2( + wv0.x - (1 << (CR_SUBPIXEL_LOG2 - 1)), + wv0.y - (1 << (CR_SUBPIXEL_LOG2 - 1)) + ); + uint3 zpleq = setupPleq(zvert, zv0, d1, d2, 1.0f / (F32)area); + + U32 zmin = f32_to_u32_sat(fminf(fminf(zvert.x, zvert.y), zvert.z) - (F32)CR_LERP_ERROR(0)); + + // Write CRTriangleData. + + *(uint4*)td = make_uint4(zpleq.x, zpleq.y, zpleq.z, triId); + + // Determine flipbits. + + U32 f01 = cover8x8_selectFlips(d1.x, d1.y); + U32 f12 = cover8x8_selectFlips(d2.x - d1.x, d2.y - d1.y); + U32 f20 = cover8x8_selectFlips(-d2.x, -d2.y); + + // Write CRTriangleHeader. + + *(uint4*)th = make_uint4( + prmt(p0.x, p0.y, 0x5410), + prmt(p1.x, p1.y, 0x5410), + prmt(p2.x, p2.y, 0x5410), + (zmin & 0xfffff000u) | (f01 << 6) | (f12 << 2) | (f20 >> 2)); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void triangleSetupImpl(const CRParams p) +{ + __shared__ F32 s_bary[CR_SETUP_WARPS * 32][18]; + F32* bary = s_bary[threadIdx.x + threadIdx.y * 32]; + + // Compute task and image indices. + + int taskIdx = threadIdx.x + 32 * (threadIdx.y + CR_SETUP_WARPS * blockIdx.x); + int imageIdx = 0; + if (p.instanceMode) + { + imageIdx = blockIdx.z; + if (taskIdx >= p.numTriangles) + return; + } + else + { + while (imageIdx < p.numImages) + { + int count = getImageParams(p, imageIdx).triCount; + if (taskIdx < count) + break; + taskIdx -= count; + imageIdx += 1; + } + if (imageIdx == p.numImages) + return; + } + + // Per-image data structures. + + const CRImageParams& ip = getImageParams(p, imageIdx); + CRAtomics& atomics = p.atomics[imageIdx]; + + const int* indexBuffer = (const int*)p.indexBuffer; + U8* triSubtris = (U8*)p.triSubtris + imageIdx * p.maxSubtris; + CRTriangleHeader* triHeader = (CRTriangleHeader*)p.triHeader + imageIdx * p.maxSubtris; + CRTriangleData* triData = (CRTriangleData*)p.triData + imageIdx * p.maxSubtris; + + // Determine triangle index. + + int triIdx = taskIdx; + if (!p.instanceMode) + triIdx += ip.triOffset; + + // Read vertex indices. + + if ((U32)triIdx >= (U32)p.numTriangles) + { + // Bad triangle index. + triSubtris[taskIdx] = 0; + return; + } + + uint4 vidx; + vidx.x = indexBuffer[triIdx * 3 + 0]; + vidx.y = indexBuffer[triIdx * 3 + 1]; + vidx.z = indexBuffer[triIdx * 3 + 2]; + vidx.w = triIdx + 1; // Triangle index. + + if (vidx.x >= (U32)p.numVertices || + vidx.y >= (U32)p.numVertices || + vidx.z >= (U32)p.numVertices) + { + // Bad vertex index. + triSubtris[taskIdx] = 0; + return; + } + + // Read vertex positions. + + const float4* vertexBuffer = (const float4*)p.vertexBuffer; + if (p.instanceMode) + vertexBuffer += p.numVertices * imageIdx; // Instance offset. + + float4 v0 = vertexBuffer[vidx.x]; + float4 v1 = vertexBuffer[vidx.y]; + float4 v2 = vertexBuffer[vidx.z]; + + // Adjust vertex positions according to current viewport size and offset. + + v0.x = v0.x * p.xs + v0.w * p.xo; + v0.y = v0.y * p.ys + v0.w * p.yo; + v1.x = v1.x * p.xs + v1.w * p.xo; + v1.y = v1.y * p.ys + v1.w * p.yo; + v2.x = v2.x * p.xs + v2.w * p.xo; + v2.y = v2.y * p.ys + v2.w * p.yo; + + // Outside view frustum => cull. + + if (v0.w < fabsf(v0.x) | v0.w < fabsf(v0.y) | v0.w < fabsf(v0.z)) + { + if ((v0.w < +v0.x & v1.w < +v1.x & v2.w < +v2.x) | + (v0.w < -v0.x & v1.w < -v1.x & v2.w < -v2.x) | + (v0.w < +v0.y & v1.w < +v1.y & v2.w < +v2.y) | + (v0.w < -v0.y & v1.w < -v1.y & v2.w < -v2.y) | + (v0.w < +v0.z & v1.w < +v1.z & v2.w < +v2.z) | + (v0.w < -v0.z & v1.w < -v1.z & v2.w < -v2.z)) + { + triSubtris[taskIdx] = 0; + return; + } + } + + // Inside depth range => try to snap vertices. + + if (v0.w >= fabsf(v0.z) & v1.w >= fabsf(v1.z) & v2.w >= fabsf(v2.z)) + { + // Inside S16 range and small enough => fast path. + // Note: aabbLimit comes from the fact that cover8x8 + // does not support guardband with maximal viewport. + + int2 p0, p1, p2, lo, hi; + float3 rcpW; + + snapTriangle(p, v0, v1, v2, p0, p1, p2, rcpW, lo, hi); + S32 loxy = ::min(lo.x, lo.y); + S32 hixy = ::max(hi.x, hi.y); + S32 aabbLimit = (1 << (CR_MAXVIEWPORT_LOG2 + CR_SUBPIXEL_LOG2)) - 1; + + if (loxy >= -32768 && hixy <= 32767 && hixy - loxy <= aabbLimit) + { + int2 d1, d2; + S32 area; + bool res = prepareTriangle(p, p0, p1, p2, lo, hi, d1, d2, area); + triSubtris[taskIdx] = res ? 1 : 0; + + if (res) + setupTriangle( + p, + &triHeader[taskIdx], &triData[taskIdx], vidx.w, + v0.z, v1.z, v2.z, + p0, p1, p2, rcpW, + d1, d2, area); + + return; + } + } + + // Clip to view frustum. + + float4 ov0 = v0; + float4 od1 = make_float4(v1.x - v0.x, v1.y - v0.y, v1.z - v0.z, v1.w - v0.w); + float4 od2 = make_float4(v2.x - v0.x, v2.y - v0.y, v2.z - v0.z, v2.w - v0.w); + int numVerts = clipTriangleWithFrustum(bary, &ov0.x, &v1.x, &v2.x, &od1.x, &od2.x); + + // Count non-culled subtriangles. + + v0.x = ov0.x + od1.x * bary[0] + od2.x * bary[1]; + v0.y = ov0.y + od1.y * bary[0] + od2.y * bary[1]; + v0.z = ov0.z + od1.z * bary[0] + od2.z * bary[1]; + v0.w = ov0.w + od1.w * bary[0] + od2.w * bary[1]; + v1.x = ov0.x + od1.x * bary[2] + od2.x * bary[3]; + v1.y = ov0.y + od1.y * bary[2] + od2.y * bary[3]; + v1.z = ov0.z + od1.z * bary[2] + od2.z * bary[3]; + v1.w = ov0.w + od1.w * bary[2] + od2.w * bary[3]; + float4 tv1 = v1; + + int numSubtris = 0; + for (int i = 2; i < numVerts; i++) + { + v2.x = ov0.x + od1.x * bary[i * 2 + 0] + od2.x * bary[i * 2 + 1]; + v2.y = ov0.y + od1.y * bary[i * 2 + 0] + od2.y * bary[i * 2 + 1]; + v2.z = ov0.z + od1.z * bary[i * 2 + 0] + od2.z * bary[i * 2 + 1]; + v2.w = ov0.w + od1.w * bary[i * 2 + 0] + od2.w * bary[i * 2 + 1]; + + int2 p0, p1, p2, lo, hi, d1, d2; + float3 rcpW; + S32 area; + + snapTriangle(p, v0, v1, v2, p0, p1, p2, rcpW, lo, hi); + if (prepareTriangle(p, p0, p1, p2, lo, hi, d1, d2, area)) + numSubtris++; + + v1 = v2; + } + + triSubtris[taskIdx] = numSubtris; + + // Multiple subtriangles => allocate. + + int subtriBase = taskIdx; + if (numSubtris > 1) + { + subtriBase = atomicAdd(&atomics.numSubtris, numSubtris); + triHeader[taskIdx].misc = subtriBase; + if (subtriBase + numSubtris > p.maxSubtris) + numVerts = 0; + } + + // Setup subtriangles. + + v1 = tv1; + for (int i = 2; i < numVerts; i++) + { + v2.x = ov0.x + od1.x * bary[i * 2 + 0] + od2.x * bary[i * 2 + 1]; + v2.y = ov0.y + od1.y * bary[i * 2 + 0] + od2.y * bary[i * 2 + 1]; + v2.z = ov0.z + od1.z * bary[i * 2 + 0] + od2.z * bary[i * 2 + 1]; + v2.w = ov0.w + od1.w * bary[i * 2 + 0] + od2.w * bary[i * 2 + 1]; + + int2 p0, p1, p2, lo, hi, d1, d2; + float3 rcpW; + S32 area; + + snapTriangle(p, v0, v1, v2, p0, p1, p2, rcpW, lo, hi); + if (prepareTriangle(p, p0, p1, p2, lo, hi, d1, d2, area)) + { + setupTriangle( + p, + &triHeader[subtriBase], &triData[subtriBase], vidx.w, + v0.z, v1.z, v2.z, + p0, p1, p2, rcpW, + d1, d2, area); + + subtriBase++; + } + + v1 = v2; + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Util.inl b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Util.inl new file mode 100644 index 0000000000000000000000000000000000000000..cf5bb79f9606e3b81e904c2815910e3a0626598e --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/cudaraster/impl/Util.inl @@ -0,0 +1,452 @@ +// Copyright (c) 2009-2022, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "PrivateDefs.hpp" + +namespace CR +{ +//------------------------------------------------------------------------ + +template __device__ __inline__ void swap(T& a, T& b) { T t = a; a = b; b = t; } + +__device__ __inline__ U32 getLo (U64 a) { return __double2loint(__longlong_as_double(a)); } +__device__ __inline__ S32 getLo (S64 a) { return __double2loint(__longlong_as_double(a)); } +__device__ __inline__ U32 getHi (U64 a) { return __double2hiint(__longlong_as_double(a)); } +__device__ __inline__ S32 getHi (S64 a) { return __double2hiint(__longlong_as_double(a)); } +__device__ __inline__ U64 combineLoHi (U32 lo, U32 hi) { return __double_as_longlong(__hiloint2double(hi, lo)); } +__device__ __inline__ S64 combineLoHi (S32 lo, S32 hi) { return __double_as_longlong(__hiloint2double(hi, lo)); } +__device__ __inline__ U32 getLaneMaskLt (void) { U32 r; asm("mov.u32 %0, %lanemask_lt;" : "=r"(r)); return r; } +__device__ __inline__ U32 getLaneMaskLe (void) { U32 r; asm("mov.u32 %0, %lanemask_le;" : "=r"(r)); return r; } +__device__ __inline__ U32 getLaneMaskGt (void) { U32 r; asm("mov.u32 %0, %lanemask_gt;" : "=r"(r)); return r; } +__device__ __inline__ U32 getLaneMaskGe (void) { U32 r; asm("mov.u32 %0, %lanemask_ge;" : "=r"(r)); return r; } +__device__ __inline__ int findLeadingOne (U32 v) { U32 r; asm("bfind.u32 %0, %1;" : "=r"(r) : "r"(v)); return r; } +__device__ __inline__ bool singleLane (void) { return ((::__ballot_sync(~0u, true) & getLaneMaskLt()) == 0); } + +__device__ __inline__ void add_add_carry (U32& rlo, U32 alo, U32 blo, U32& rhi, U32 ahi, U32 bhi) { U64 r = combineLoHi(alo, ahi) + combineLoHi(blo, bhi); rlo = getLo(r); rhi = getHi(r); } +__device__ __inline__ S32 f32_to_s32_sat (F32 a) { S32 v; asm("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ U32 f32_to_u32_sat (F32 a) { U32 v; asm("cvt.rni.sat.u32.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ U32 f32_to_u32_sat_rmi (F32 a) { U32 v; asm("cvt.rmi.sat.u32.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ U32 f32_to_u8_sat (F32 a) { U32 v; asm("cvt.rni.sat.u8.f32 %0, %1;" : "=r"(v) : "f"(a)); return v; } +__device__ __inline__ S64 f32_to_s64 (F32 a) { S64 v; asm("cvt.rni.s64.f32 %0, %1;" : "=l"(v) : "f"(a)); return v; } +__device__ __inline__ S32 add_s16lo_s16lo (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h0, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 add_s16hi_s16lo (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h1, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 add_s16lo_s16hi (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h0, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 add_s16hi_s16hi (S32 a, S32 b) { S32 v; asm("vadd.s32.s32.s32 %0, %1.h1, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16lo_s16lo (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h0, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16hi_s16lo (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h1, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16lo_s16hi (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h0, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_s16hi_s16hi (S32 a, S32 b) { S32 v; asm("vsub.s32.s32.s32 %0, %1.h1, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16lo_u16lo (U32 a, U32 b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h0, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16hi_u16lo (U32 a, U32 b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h1, %2.h0;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16lo_u16hi (U32 a, U32 b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h0, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ S32 sub_u16hi_u16hi (U32 a, U32 b) { S32 v; asm("vsub.s32.u32.u32 %0, %1.h1, %2.h1;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b0 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b0, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b1 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b1, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b2 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b2, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 add_b3 (U32 a, U32 b) { U32 v; asm("vadd.u32.u32.u32 %0, %1.b3, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ U32 vmad_b0 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b0, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b1 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b2 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b2, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b3, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b0_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b0, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b1_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b1, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b2_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b2, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 vmad_b3_b3 (U32 a, U32 b, U32 c) { U32 v; asm("vmad.u32.u32.u32 %0, %1.b3, %2.b3, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 add_mask8 (U32 a, U32 b) { U32 v; U32 z=0; asm("vadd.u32.u32.u32 %0.b0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(z)); return v; } +__device__ __inline__ U32 sub_mask8 (U32 a, U32 b) { U32 v; U32 z=0; asm("vsub.u32.u32.u32 %0.b0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(z)); return v; } +__device__ __inline__ S32 max_max (S32 a, S32 b, S32 c) { S32 v; asm("vmax.s32.s32.s32.max %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 min_min (S32 a, S32 b, S32 c) { S32 v; asm("vmin.s32.s32.s32.min %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 max_add (S32 a, S32 b, S32 c) { S32 v; asm("vmax.s32.s32.s32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 min_add (S32 a, S32 b, S32 c) { S32 v; asm("vmin.s32.s32.s32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 add_add (U32 a, U32 b, U32 c) { U32 v; asm("vadd.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 sub_add (U32 a, U32 b, U32 c) { U32 v; asm("vsub.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 add_sub (U32 a, U32 b, U32 c) { U32 v; asm("vsub.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(c), "r"(b)); return v; } +__device__ __inline__ S32 add_clamp_0_x (S32 a, S32 b, S32 c) { S32 v; asm("vadd.u32.s32.s32.sat.min %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 add_clamp_b0 (S32 a, S32 b, S32 c) { S32 v; asm("vadd.u32.s32.s32.sat %0.b0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 add_clamp_b2 (S32 a, S32 b, S32 c) { S32 v; asm("vadd.u32.s32.s32.sat %0.b2, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ U32 prmt (U32 a, U32 b, U32 c) { U32 v; asm("prmt.b32 %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 u32lo_sext (U32 a) { U32 v; asm("cvt.s16.u32 %0, %1;" : "=r"(v) : "r"(a)); return v; } +__device__ __inline__ U32 slct (U32 a, U32 b, S32 c) { U32 v; asm("slct.u32.s32 %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ S32 slct (S32 a, S32 b, S32 c) { S32 v; asm("slct.s32.s32 %0, %1, %2, %3;" : "=r"(v) : "r"(a), "r"(b), "r"(c)); return v; } +__device__ __inline__ F32 slct (F32 a, F32 b, S32 c) { F32 v; asm("slct.f32.s32 %0, %1, %2, %3;" : "=f"(v) : "f"(a), "f"(b), "r"(c)); return v; } +__device__ __inline__ U32 isetge (S32 a, S32 b) { U32 v; asm("set.ge.u32.s32 %0, %1, %2;" : "=r"(v) : "r"(a), "r"(b)); return v; } +__device__ __inline__ F64 rcp_approx (F64 a) { F64 v; asm("rcp.approx.ftz.f64 %0, %1;" : "=d"(v) : "d"(a)); return v; } +__device__ __inline__ F32 fma_rm (F32 a, F32 b, F32 c) { F32 v; asm("fma.rm.f32 %0, %1, %2, %3;" : "=f"(v) : "f"(a), "f"(b), "f"(c)); return v; } +__device__ __inline__ U32 idiv_fast (U32 a, U32 b); + +__device__ __inline__ uint3 setupPleq (float3 values, int2 v0, int2 d1, int2 d2, F32 areaRcp); + +__device__ __inline__ void cover8x8_setupLUT (volatile U64* lut); +__device__ __inline__ U64 cover8x8_exact_fast (S32 ox, S32 oy, S32 dx, S32 dy, U32 flips, volatile const U64* lut); // Assumes viewport <= 2^11, subpixels <= 2^4, no guardband. +__device__ __inline__ U64 cover8x8_lookupMask (S64 yinit, U32 yinc, U32 flips, volatile const U64* lut); + +__device__ __inline__ U64 cover8x8_exact_noLUT (S32 ox, S32 oy, S32 dx, S32 dy); // optimized reference implementation, does not require look-up table +__device__ __inline__ U64 cover8x8_conservative_noLUT (S32 ox, S32 oy, S32 dx, S32 dy); +__device__ __inline__ U64 cover8x8_generateMask_noLUT (S32 curr, S32 dx, S32 dy); + +template __device__ __inline__ void sortShared(T* ptr, int numItems); // Assumes that numItems <= threadsInBlock. Must sync before & after the call. + +__device__ __inline__ const CRImageParams& getImageParams(const CRParams& p, int idx) +{ + return (idx < CR_EMBED_IMAGE_PARAMS) ? p.imageParamsFirst[idx] : p.imageParamsExtra[idx - CR_EMBED_IMAGE_PARAMS]; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ int clipPolygonWithPlane(F32* baryOut, const F32* baryIn, int numIn, F32 v0, F32 v1, F32 v2) +{ + int numOut = 0; + if (numIn >= 3) + { + int ai = (numIn - 1) * 2; + F32 av = v0 + v1 * baryIn[ai + 0] + v2 * baryIn[ai + 1]; + for (int bi = 0; bi < numIn * 2; bi += 2) + { + F32 bv = v0 + v1 * baryIn[bi + 0] + v2 * baryIn[bi + 1]; + if (av * bv < 0.0f) + { + F32 bc = av / (av - bv); + F32 ac = 1.0f - bc; + baryOut[numOut + 0] = baryIn[ai + 0] * ac + baryIn[bi + 0] * bc; + baryOut[numOut + 1] = baryIn[ai + 1] * ac + baryIn[bi + 1] * bc; + numOut += 2; + } + if (bv >= 0.0f) + { + baryOut[numOut + 0] = baryIn[bi + 0]; + baryOut[numOut + 1] = baryIn[bi + 1]; + numOut += 2; + } + ai = bi; + av = bv; + } + } + return (numOut >> 1); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ int clipTriangleWithFrustum(F32* bary, const F32* v0, const F32* v1, const F32* v2, const F32* d1, const F32* d2) +{ + int num = 3; + bary[0] = 0.0f, bary[1] = 0.0f; + bary[2] = 1.0f, bary[3] = 0.0f; + bary[4] = 0.0f, bary[5] = 1.0f; + + if ((v0[3] < fabsf(v0[0])) | (v1[3] < fabsf(v1[0])) | (v2[3] < fabsf(v2[0]))) + { + F32 temp[18]; + num = clipPolygonWithPlane(temp, bary, num, v0[3] + v0[0], d1[3] + d1[0], d2[3] + d2[0]); + num = clipPolygonWithPlane(bary, temp, num, v0[3] - v0[0], d1[3] - d1[0], d2[3] - d2[0]); + } + if ((v0[3] < fabsf(v0[1])) | (v1[3] < fabsf(v1[1])) | (v2[3] < fabsf(v2[1]))) + { + F32 temp[18]; + num = clipPolygonWithPlane(temp, bary, num, v0[3] + v0[1], d1[3] + d1[1], d2[3] + d2[1]); + num = clipPolygonWithPlane(bary, temp, num, v0[3] - v0[1], d1[3] - d1[1], d2[3] - d2[1]); + } + if ((v0[3] < fabsf(v0[2])) | (v1[3] < fabsf(v1[2])) | (v2[3] < fabsf(v2[2]))) + { + F32 temp[18]; + num = clipPolygonWithPlane(temp, bary, num, v0[3] + v0[2], d1[3] + d1[2], d2[3] + d2[2]); + num = clipPolygonWithPlane(bary, temp, num, v0[3] - v0[2], d1[3] - d1[2], d2[3] - d2[2]); + } + return num; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 idiv_fast(U32 a, U32 b) +{ + return f32_to_u32_sat_rmi(((F32)a + 0.5f) / (F32)b); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U32 toABGR(float4 color) +{ + // 11 instructions: 4*FFMA, 4*F2I, 3*PRMT + U32 x = f32_to_u32_sat_rmi(fma_rm(color.x, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + U32 y = f32_to_u32_sat_rmi(fma_rm(color.y, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + U32 z = f32_to_u32_sat_rmi(fma_rm(color.z, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + U32 w = f32_to_u32_sat_rmi(fma_rm(color.w, (1 << 24) * 255.0f, (1 << 24) * 0.5f)); + return prmt(prmt(x, y, 0x0073), prmt(z, w, 0x0073), 0x5410); +} + +//------------------------------------------------------------------------ +// v0 = subpixels relative to the bottom-left sampling point + +__device__ __inline__ uint3 setupPleq(float3 values, int2 v0, int2 d1, int2 d2, F32 areaRcp) +{ + F32 mx = fmaxf(fmaxf(values.x, values.y), values.z); + int sh = ::min(::max((__float_as_int(mx) >> 23) - (127 + 22), 0), 8); + S32 t0 = (U32)values.x >> sh; + S32 t1 = ((U32)values.y >> sh) - t0; + S32 t2 = ((U32)values.z >> sh) - t0; + + U32 rcpMant = (__float_as_int(areaRcp) & 0x007FFFFF) | 0x00800000; + int rcpShift = (23 + 127) - (__float_as_int(areaRcp) >> 23); + + uint3 pleq; + S64 xc = ((S64)t1 * d2.y - (S64)t2 * d1.y) * rcpMant; + S64 yc = ((S64)t2 * d1.x - (S64)t1 * d2.x) * rcpMant; + pleq.x = (U32)(xc >> (rcpShift - (sh + CR_SUBPIXEL_LOG2))); + pleq.y = (U32)(yc >> (rcpShift - (sh + CR_SUBPIXEL_LOG2))); + + S32 centerX = (v0.x * 2 + min_min(d1.x, d2.x, 0) + max_max(d1.x, d2.x, 0)) >> (CR_SUBPIXEL_LOG2 + 1); + S32 centerY = (v0.y * 2 + min_min(d1.y, d2.y, 0) + max_max(d1.y, d2.y, 0)) >> (CR_SUBPIXEL_LOG2 + 1); + S32 vcx = v0.x - (centerX << CR_SUBPIXEL_LOG2); + S32 vcy = v0.y - (centerY << CR_SUBPIXEL_LOG2); + + pleq.z = t0 << sh; + pleq.z -= (U32)(((xc >> 13) * vcx + (yc >> 13) * vcy) >> (rcpShift - (sh + 13))); + pleq.z -= pleq.x * centerX + pleq.y * centerY; + return pleq; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ void cover8x8_setupLUT(volatile U64* lut) +{ + for (S32 lutIdx = threadIdx.x + blockDim.x * threadIdx.y; lutIdx < CR_COVER8X8_LUT_SIZE; lutIdx += blockDim.x * blockDim.y) + { + int half = (lutIdx < (12 << 5)) ? 0 : 1; + int yint = (lutIdx >> 5) - half * 12 - 3; + U32 shape = ((lutIdx >> 2) & 7) << (31 - 2); + S32 slctSwapXY = lutIdx << (31 - 1); + S32 slctNegX = lutIdx << (31 - 0); + S32 slctCompl = slctSwapXY ^ slctNegX; + + U64 mask = 0; + int xlo = half * 4; + int xhi = xlo + 4; + for (int x = xlo; x < xhi; x++) + { + int ylo = slct(0, ::max(yint, 0), slctCompl); + int yhi = slct(::min(yint, 8), 8, slctCompl); + for (int y = ylo; y < yhi; y++) + { + int xx = slct(x, y, slctSwapXY); + int yy = slct(y, x, slctSwapXY); + xx = slct(xx, 7 - xx, slctNegX); + mask |= (U64)1 << (xx + yy * 8); + } + yint += shape >> 31; + shape <<= 1; + } + lut[lutIdx] = mask; + } +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_exact_fast(S32 ox, S32 oy, S32 dx, S32 dy, U32 flips, volatile const U64* lut) // 52 instr +{ + F32 yinitBias = (F32)(1 << (31 - CR_MAXVIEWPORT_LOG2 - CR_SUBPIXEL_LOG2 * 2)); + F32 yinitScale = (F32)(1 << (32 - CR_SUBPIXEL_LOG2)); + F32 yincScale = 65536.0f * 65536.0f; + + S32 slctFlipY = flips << (31 - CR_FLIPBIT_FLIP_Y); + S32 slctFlipX = flips << (31 - CR_FLIPBIT_FLIP_X); + S32 slctSwapXY = flips << (31 - CR_FLIPBIT_SWAP_XY); + + // Evaluate cross product. + + S32 t = ox * dy - oy * dx; + F32 det = (F32)slct(t, t - dy * (7 << CR_SUBPIXEL_LOG2), slctFlipX); + if (flips >= (1 << CR_FLIPBIT_COMPL)) + det = -det; + + // Represent Y as a function of X. + + F32 xrcp = 1.0f / (F32)::abs(slct(dx, dy, slctSwapXY)); + F32 yzero = det * yinitScale * xrcp + yinitBias; + S64 yinit = f32_to_s64(slct(yzero, -yzero, slctFlipY)); + U32 yinc = f32_to_u32_sat((F32)::abs(slct(dy, dx, slctSwapXY)) * xrcp * yincScale); + + // Lookup. + + return cover8x8_lookupMask(yinit, yinc, flips, lut); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_lookupMask(S64 yinit, U32 yinc, U32 flips, volatile const U64* lut) +{ + // First half. + + U32 yfrac = getLo(yinit); + U32 shape = add_clamp_0_x(getHi(yinit) + 4, 0, 11); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + int oct = flips & ((1 << CR_FLIPBIT_FLIP_X) | (1 << CR_FLIPBIT_SWAP_XY)); + U64 mask = *(U64*)((U8*)lut + oct + (shape << 5)); + + // Second half. + + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + shape = add_clamp_0_x(getHi(yinit) + 4, __popc(shape & 15), 11); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + add_add_carry(yfrac, yfrac, yinc, shape, shape, shape); + mask |= *(U64*)((U8*)lut + oct + (shape << 5) + (12 << 8)); + return (flips >= (1 << CR_FLIPBIT_COMPL)) ? ~mask : mask; +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_exact_noLUT(S32 ox, S32 oy, S32 dx, S32 dy) +{ + S32 curr = ox * dy - oy * dx; + if (dy > 0 || (dy == 0 && dx <= 0)) curr--; // exclusive + return cover8x8_generateMask_noLUT(curr, dx, dy); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_conservative_noLUT(S32 ox, S32 oy, S32 dx, S32 dy) +{ + S32 curr = ox * dy - oy * dx; + if (dy > 0 || (dy == 0 && dx <= 0)) curr--; // exclusive + curr += (::abs(dx) + ::abs(dy)) << (CR_SUBPIXEL_LOG2 - 1); + return cover8x8_generateMask_noLUT(curr, dx, dy); +} + +//------------------------------------------------------------------------ + +__device__ __inline__ U64 cover8x8_generateMask_noLUT(S32 curr, S32 dx, S32 dy) +{ + curr += (dx - dy) * (7 << CR_SUBPIXEL_LOG2); + S32 stepX = dy << (CR_SUBPIXEL_LOG2 + 1); + S32 stepYorig = -dx - dy * 7; + S32 stepY = stepYorig << (CR_SUBPIXEL_LOG2 + 1); + + U32 hi = isetge(curr, 0); + U32 frac = curr + curr; + for (int i = 62; i >= 32; i--) + add_add_carry(frac, frac, ((i & 7) == 7) ? stepY : stepX, hi, hi, hi); + + U32 lo = 0; + for (int i = 31; i >= 0; i--) + add_add_carry(frac, frac, ((i & 7) == 7) ? stepY : stepX, lo, lo, lo); + + lo ^= lo >> 1, hi ^= hi >> 1; + lo ^= lo >> 2, hi ^= hi >> 2; + lo ^= lo >> 4, hi ^= hi >> 4; + lo ^= lo >> 8, hi ^= hi >> 8; + lo ^= lo >> 16, hi ^= hi >> 16; + + if (dy < 0) + { + lo ^= 0x55AA55AA; + hi ^= 0x55AA55AA; + } + if (stepYorig < 0) + { + lo ^= 0xFF00FF00; + hi ^= 0x00FF00FF; + } + if ((hi & 1) != 0) + lo = ~lo; + + return combineLoHi(lo, hi); +} + +//------------------------------------------------------------------------ + +template __device__ __inline__ void sortShared(T* ptr, int numItems) +{ + int thrInBlock = threadIdx.x + threadIdx.y * blockDim.x; + int range = 16; + + // Use transposition sort within each 16-wide subrange. + + int base = thrInBlock * 2; + bool act = (base < numItems - 1); + U32 actMask = __ballot_sync(~0u, act); + if (act) + { + bool tryOdd = (base < numItems - 2 && (~base & (range - 2)) != 0); + T mid = ptr[base + 1]; + + for (int iter = 0; iter < range; iter += 2) + { + // Evens. + + T tmp = ptr[base + 0]; + if (tmp > mid) + { + ptr[base + 0] = mid; + mid = tmp; + } + __syncwarp(actMask); + + // Odds. + + if (tryOdd) + { + tmp = ptr[base + 2]; + if (mid > tmp) + { + ptr[base + 2] = mid; + mid = tmp; + } + } + __syncwarp(actMask); + } + ptr[base + 1] = mid; + } + + // Multiple subranges => Merge hierarchically. + + for (; range < numItems; range <<= 1) + { + // Assuming that we would insert the current item into the other + // subrange, use binary search to find the appropriate slot. + + __syncthreads(); + + T item; + int slot; + if (thrInBlock < numItems) + { + item = ptr[thrInBlock]; + slot = (thrInBlock & -range) ^ range; + if (slot < numItems) + { + T tmp = ptr[slot]; + bool inclusive = ((thrInBlock & range) != 0); + if (tmp < item || (inclusive && tmp == item)) + { + for (int step = (range >> 1); step != 0; step >>= 1) + { + int probe = slot + step; + if (probe < numItems) + { + tmp = ptr[probe]; + if (tmp < item || (inclusive && tmp == item)) + slot = probe; + } + } + slot++; + } + } + } + + // Store the item at an appropriate place. + + __syncthreads(); + + if (thrInBlock < numItems) + ptr[slot + (thrInBlock & (range * 2 - 1)) - range] = item; + } +} + +//------------------------------------------------------------------------ +} diff --git a/extensions/nvdiffrast/nvdiffrast/common/framework.h b/extensions/nvdiffrast/nvdiffrast/common/framework.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1b8e993f1c21bd9f0129e14645db9e773cec71 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/framework.h @@ -0,0 +1,49 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +// Framework-specific macros to enable code sharing. + +//------------------------------------------------------------------------ +// Tensorflow. + +#ifdef NVDR_TENSORFLOW +#define EIGEN_USE_GPU +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/default/logging.h" +using namespace tensorflow; +using namespace tensorflow::shape_inference; +#define NVDR_CTX_ARGS OpKernelContext* _nvdr_ctx +#define NVDR_CTX_PARAMS _nvdr_ctx +#define NVDR_CHECK(COND, ERR) OP_REQUIRES(_nvdr_ctx, COND, errors::Internal(ERR)) +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) OP_CHECK_CUDA_ERROR(_nvdr_ctx, CUDA_CALL) +#define NVDR_CHECK_GL_ERROR(GL_CALL) OP_CHECK_GL_ERROR(_nvdr_ctx, GL_CALL) +#endif + +//------------------------------------------------------------------------ +// PyTorch. + +#ifdef NVDR_TORCH +#ifndef __CUDACC__ +#include +#include +#include +#include +#include +#endif +#define NVDR_CTX_ARGS int _nvdr_ctx_dummy +#define NVDR_CTX_PARAMS 0 +#define NVDR_CHECK(COND, ERR) do { TORCH_CHECK(COND, ERR) } while(0) +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) do { cudaError_t err = CUDA_CALL; TORCH_CHECK(!err, "Cuda error: ", cudaGetLastError(), "[", #CUDA_CALL, ";]"); } while(0) +#define NVDR_CHECK_GL_ERROR(GL_CALL) do { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); } while(0) +#endif + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/glutil.cpp b/extensions/nvdiffrast/nvdiffrast/common/glutil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bea4d86bf8ae95bc958a0cc8dd18238e4b663324 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/glutil.cpp @@ -0,0 +1,403 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Common. +//------------------------------------------------------------------------ + +#include "framework.h" +#include "glutil.h" +#include +#include + +// Create the function pointers. +#define GLUTIL_EXT(return_type, name, ...) return_type (GLAPIENTRY* name)(__VA_ARGS__) = 0; +#include "glutil_extlist.h" +#undef GLUTIL_EXT + +// Track initialization status. +static volatile bool s_glExtInitialized = false; + +// Error strings. +const char* getGLErrorString(GLenum err) +{ + switch(err) + { + case GL_NO_ERROR: return "GL_NO_ERROR"; + case GL_INVALID_ENUM: return "GL_INVALID_ENUM"; + case GL_INVALID_VALUE: return "GL_INVALID_VALUE"; + case GL_INVALID_OPERATION: return "GL_INVALID_OPERATION"; + case GL_STACK_OVERFLOW: return "GL_STACK_OVERFLOW"; + case GL_STACK_UNDERFLOW: return "GL_STACK_UNDERFLOW"; + case GL_OUT_OF_MEMORY: return "GL_OUT_OF_MEMORY"; + case GL_INVALID_FRAMEBUFFER_OPERATION: return "GL_INVALID_FRAMEBUFFER_OPERATION"; + case GL_TABLE_TOO_LARGE: return "GL_TABLE_TOO_LARGE"; + case GL_CONTEXT_LOST: return "GL_CONTEXT_LOST"; + } + return "Unknown error"; +} + +//------------------------------------------------------------------------ +// Windows. +//------------------------------------------------------------------------ + +#ifdef _WIN32 + +static CRITICAL_SECTION getInitializedCriticalSection(void) +{ + CRITICAL_SECTION cs; + InitializeCriticalSection(&cs); + return cs; +} + +static CRITICAL_SECTION s_getProcAddressMutex = getInitializedCriticalSection(); + +static void safeGetProcAddress(const char* name, PROC* pfn) +{ + PROC result = wglGetProcAddress(name); + if (!result) + { + LeaveCriticalSection(&s_getProcAddressMutex); // Prepare for thread exit. + LOG(FATAL) << "wglGetProcAddress() failed for '" << name << "'"; + exit(1); // Should never get here but make sure we exit. + } + *pfn = result; +} + +static void initializeGLExtensions(void) +{ + // Use critical section for thread safety. + EnterCriticalSection(&s_getProcAddressMutex); + + // Only dig function pointers if not done already. + if (!s_glExtInitialized) + { + // Generate code to populate the function pointers. +#define GLUTIL_EXT(return_type, name, ...) safeGetProcAddress(#name, (PROC*)&name); +#include "glutil_extlist.h" +#undef GLUTIL_EXT + + // Mark as initialized. + s_glExtInitialized = true; + } + + // Done. + LeaveCriticalSection(&s_getProcAddressMutex); + return; +} + +void setGLContext(GLContext& glctx) +{ + if (!glctx.hglrc) + LOG(FATAL) << "setGLContext() called with null gltcx"; + if (!wglMakeCurrent(glctx.hdc, glctx.hglrc)) + LOG(FATAL) << "wglMakeCurrent() failed when setting GL context"; + + if (glctx.extInitialized) + return; + initializeGLExtensions(); + glctx.extInitialized = 1; +} + +void releaseGLContext(void) +{ + if (!wglMakeCurrent(NULL, NULL)) + LOG(FATAL) << "wglMakeCurrent() failed when releasing GL context"; +} + +extern "C" int set_gpu(const char*); // In setgpu.lib +GLContext createGLContext(int cudaDeviceIdx) +{ + if (cudaDeviceIdx >= 0) + { + char pciBusId[256] = ""; + LOG(INFO) << "Creating GL context for Cuda device " << cudaDeviceIdx; + if (cudaDeviceGetPCIBusId(pciBusId, 255, cudaDeviceIdx)) + { + LOG(INFO) << "PCI bus id query failed"; + } + else + { + int res = set_gpu(pciBusId); + LOG(INFO) << "Selecting device with PCI bus id " << pciBusId << " - " << (res ? "failed, expect crash or major slowdown" : "success"); + } + } + + HINSTANCE hInstance = GetModuleHandle(NULL); + WNDCLASS wc = {}; + wc.style = CS_OWNDC; + wc.lpfnWndProc = DefWindowProc; + wc.hInstance = hInstance; + wc.lpszClassName = "__DummyGLClassCPP"; + int res = RegisterClass(&wc); + + HWND hwnd = CreateWindow( + "__DummyGLClassCPP", // lpClassName + "__DummyGLWindowCPP", // lpWindowName + WS_OVERLAPPEDWINDOW, // dwStyle + CW_USEDEFAULT, // x + CW_USEDEFAULT, // y + 0, 0, // nWidth, nHeight + NULL, NULL, // hWndParent, hMenu + hInstance, // hInstance + NULL // lpParam + ); + + PIXELFORMATDESCRIPTOR pfd = {}; + pfd.dwFlags = PFD_SUPPORT_OPENGL; + pfd.iPixelType = PFD_TYPE_RGBA; + pfd.iLayerType = PFD_MAIN_PLANE; + pfd.cColorBits = 32; + pfd.cDepthBits = 24; + pfd.cStencilBits = 8; + + HDC hdc = GetDC(hwnd); + int pixelformat = ChoosePixelFormat(hdc, &pfd); + SetPixelFormat(hdc, pixelformat, &pfd); + + HGLRC hglrc = wglCreateContext(hdc); + LOG(INFO) << std::hex << std::setfill('0') + << "WGL OpenGL context created (hdc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)hdc + << ", hglrc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)hglrc << ")"; + + GLContext glctx = {hdc, hglrc, 0}; + return glctx; +} + +void destroyGLContext(GLContext& glctx) +{ + if (!glctx.hglrc) + LOG(FATAL) << "destroyGLContext() called with null gltcx"; + + // If this is the current context, release it. + if (wglGetCurrentContext() == glctx.hglrc) + releaseGLContext(); + + HWND hwnd = WindowFromDC(glctx.hdc); + if (!hwnd) + LOG(FATAL) << "WindowFromDC() failed"; + if (!ReleaseDC(hwnd, glctx.hdc)) + LOG(FATAL) << "ReleaseDC() failed"; + if (!wglDeleteContext(glctx.hglrc)) + LOG(FATAL) << "wglDeleteContext() failed"; + if (!DestroyWindow(hwnd)) + LOG(FATAL) << "DestroyWindow() failed"; + + LOG(INFO) << std::hex << std::setfill('0') + << "WGL OpenGL context destroyed (hdc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)glctx.hdc + << ", hglrc: 0x" << std::setw(8) << (uint32_t)(uintptr_t)glctx.hglrc << ")"; + + memset(&glctx, 0, sizeof(GLContext)); +} + +#endif // _WIN32 + +//------------------------------------------------------------------------ +// Linux. +//------------------------------------------------------------------------ + +#ifdef __linux__ + +static pthread_mutex_t s_getProcAddressMutex; + +typedef void (*PROCFN)(); + +static void safeGetProcAddress(const char* name, PROCFN* pfn) +{ + PROCFN result = eglGetProcAddress(name); + if (!result) + { + pthread_mutex_unlock(&s_getProcAddressMutex); // Prepare for thread exit. + LOG(FATAL) << "wglGetProcAddress() failed for '" << name << "'"; + exit(1); // Should never get here but make sure we exit. + } + *pfn = result; +} + +static void initializeGLExtensions(void) +{ + pthread_mutex_lock(&s_getProcAddressMutex); + + // Only dig function pointers if not done already. + if (!s_glExtInitialized) + { + // Generate code to populate the function pointers. +#define GLUTIL_EXT(return_type, name, ...) safeGetProcAddress(#name, (PROCFN*)&name); +#include "glutil_extlist.h" +#undef GLUTIL_EXT + + // Mark as initialized. + s_glExtInitialized = true; + } + + pthread_mutex_unlock(&s_getProcAddressMutex); + return; +} + +void setGLContext(GLContext& glctx) +{ + if (!glctx.context) + LOG(FATAL) << "setGLContext() called with null gltcx"; + + if (!eglMakeCurrent(glctx.display, EGL_NO_SURFACE, EGL_NO_SURFACE, glctx.context)) + LOG(ERROR) << "eglMakeCurrent() failed when setting GL context"; + + if (glctx.extInitialized) + return; + initializeGLExtensions(); + glctx.extInitialized = 1; +} + +void releaseGLContext(void) +{ + EGLDisplay display = eglGetCurrentDisplay(); + if (display == EGL_NO_DISPLAY) + LOG(WARNING) << "releaseGLContext() called with no active display"; + if (!eglMakeCurrent(display, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT)) + LOG(FATAL) << "eglMakeCurrent() failed when releasing GL context"; +} + +static EGLDisplay getCudaDisplay(int cudaDeviceIdx) +{ + typedef EGLBoolean (*eglQueryDevicesEXT_t)(EGLint, EGLDeviceEXT, EGLint*); + typedef EGLBoolean (*eglQueryDeviceAttribEXT_t)(EGLDeviceEXT, EGLint, EGLAttrib*); + typedef EGLDisplay (*eglGetPlatformDisplayEXT_t)(EGLenum, void*, const EGLint*); + + eglQueryDevicesEXT_t eglQueryDevicesEXT = (eglQueryDevicesEXT_t)eglGetProcAddress("eglQueryDevicesEXT"); + if (!eglQueryDevicesEXT) + { + LOG(INFO) << "eglGetProcAddress(\"eglQueryDevicesEXT\") failed"; + return 0; + } + + eglQueryDeviceAttribEXT_t eglQueryDeviceAttribEXT = (eglQueryDeviceAttribEXT_t)eglGetProcAddress("eglQueryDeviceAttribEXT"); + if (!eglQueryDeviceAttribEXT) + { + LOG(INFO) << "eglGetProcAddress(\"eglQueryDeviceAttribEXT\") failed"; + return 0; + } + + eglGetPlatformDisplayEXT_t eglGetPlatformDisplayEXT = (eglGetPlatformDisplayEXT_t)eglGetProcAddress("eglGetPlatformDisplayEXT"); + if (!eglGetPlatformDisplayEXT) + { + LOG(INFO) << "eglGetProcAddress(\"eglGetPlatformDisplayEXT\") failed"; + return 0; + } + + int num_devices = 0; + eglQueryDevicesEXT(0, 0, &num_devices); + if (!num_devices) + return 0; + + EGLDisplay display = 0; + EGLDeviceEXT* devices = (EGLDeviceEXT*)malloc(num_devices * sizeof(void*)); + eglQueryDevicesEXT(num_devices, devices, &num_devices); + for (int i=0; i < num_devices; i++) + { + EGLDeviceEXT device = devices[i]; + intptr_t value = -1; + if (eglQueryDeviceAttribEXT(device, EGL_CUDA_DEVICE_NV, &value) && value == cudaDeviceIdx) + { + display = eglGetPlatformDisplayEXT(EGL_PLATFORM_DEVICE_EXT, device, 0); + break; + } + } + + free(devices); + return display; +} + +GLContext createGLContext(int cudaDeviceIdx) +{ + EGLDisplay display = 0; + + if (cudaDeviceIdx >= 0) + { + char pciBusId[256] = ""; + LOG(INFO) << "Creating GL context for Cuda device " << cudaDeviceIdx; + display = getCudaDisplay(cudaDeviceIdx); + if (!display) + LOG(INFO) << "Failed, falling back to default display"; + } + + if (!display) + { + display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + if (display == EGL_NO_DISPLAY) + LOG(FATAL) << "eglGetDisplay() failed"; + } + + EGLint major; + EGLint minor; + if (!eglInitialize(display, &major, &minor)) + LOG(FATAL) << "eglInitialize() failed"; + + // Choose configuration. + + const EGLint context_attribs[] = { + EGL_RED_SIZE, 8, + EGL_GREEN_SIZE, 8, + EGL_BLUE_SIZE, 8, + EGL_ALPHA_SIZE, 8, + EGL_DEPTH_SIZE, 24, + EGL_STENCIL_SIZE, 8, + EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, + EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, + EGL_NONE + }; + + EGLConfig config; + EGLint num_config; + if (!eglChooseConfig(display, context_attribs, &config, 1, &num_config)) + LOG(FATAL) << "eglChooseConfig() failed"; + + // Create GL context. + + if (!eglBindAPI(EGL_OPENGL_API)) + LOG(FATAL) << "eglBindAPI() failed"; + + EGLContext context = eglCreateContext(display, config, EGL_NO_CONTEXT, NULL); + if (context == EGL_NO_CONTEXT) + LOG(FATAL) << "eglCreateContext() failed"; + + // Done. + + LOG(INFO) << "EGL " << (int)minor << "." << (int)major << " OpenGL context created (disp: 0x" + << std::hex << std::setfill('0') + << std::setw(16) << (uintptr_t)display + << ", ctx: 0x" << std::setw(16) << (uintptr_t)context << ")"; + + GLContext glctx = {display, context, 0}; + return glctx; +} + +void destroyGLContext(GLContext& glctx) +{ + if (!glctx.context) + LOG(FATAL) << "destroyGLContext() called with null gltcx"; + + // If this is the current context, release it. + if (eglGetCurrentContext() == glctx.context) + releaseGLContext(); + + if (!eglDestroyContext(glctx.display, glctx.context)) + LOG(ERROR) << "eglDestroyContext() failed"; + + LOG(INFO) << "EGL OpenGL context destroyed (disp: 0x" + << std::hex << std::setfill('0') + << std::setw(16) << (uintptr_t)glctx.display + << ", ctx: 0x" << std::setw(16) << (uintptr_t)glctx.context << ")"; + + memset(&glctx, 0, sizeof(GLContext)); +} + +//------------------------------------------------------------------------ + +#endif // __linux__ + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/glutil.h b/extensions/nvdiffrast/nvdiffrast/common/glutil.h new file mode 100644 index 0000000000000000000000000000000000000000..6e4c384626324b5d36e0cc9ac002eacbb10c65f2 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/glutil.h @@ -0,0 +1,113 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Windows-specific headers and types. +//------------------------------------------------------------------------ + +#ifdef _WIN32 +#define NOMINMAX +#include // Required by gl.h in Windows. +#define GLAPIENTRY APIENTRY + +struct GLContext +{ + HDC hdc; + HGLRC hglrc; + int extInitialized; +}; + +#endif // _WIN32 + +//------------------------------------------------------------------------ +// Linux-specific headers and types. +//------------------------------------------------------------------------ + +#ifdef __linux__ +#define EGL_NO_X11 // X11/Xlib.h has "#define Status int" which breaks Tensorflow. Avoid it. +#define MESA_EGL_NO_X11_HEADERS +#include +#include +#define GLAPIENTRY + +struct GLContext +{ + EGLDisplay display; + EGLContext context; + int extInitialized; +}; + +#endif // __linux__ + +//------------------------------------------------------------------------ +// OpenGL, CUDA interop, GL extensions. +//------------------------------------------------------------------------ +#define GL_GLEXT_LEGACY +#include +#include + +// Constants. +#ifndef GL_VERSION_1_2 +#define GL_CLAMP_TO_EDGE 0x812F +#define GL_TEXTURE_3D 0x806F +#endif +#ifndef GL_VERSION_1_5 +#define GL_ARRAY_BUFFER 0x8892 +#define GL_DYNAMIC_DRAW 0x88E8 +#define GL_ELEMENT_ARRAY_BUFFER 0x8893 +#endif +#ifndef GL_VERSION_2_0 +#define GL_FRAGMENT_SHADER 0x8B30 +#define GL_INFO_LOG_LENGTH 0x8B84 +#define GL_LINK_STATUS 0x8B82 +#define GL_VERTEX_SHADER 0x8B31 +#endif +#ifndef GL_VERSION_3_0 +#define GL_MAJOR_VERSION 0x821B +#define GL_MINOR_VERSION 0x821C +#define GL_RGBA32F 0x8814 +#define GL_TEXTURE_2D_ARRAY 0x8C1A +#endif +#ifndef GL_VERSION_3_2 +#define GL_GEOMETRY_SHADER 0x8DD9 +#endif +#ifndef GL_ARB_framebuffer_object +#define GL_COLOR_ATTACHMENT0 0x8CE0 +#define GL_COLOR_ATTACHMENT1 0x8CE1 +#define GL_DEPTH_STENCIL 0x84F9 +#define GL_DEPTH_STENCIL_ATTACHMENT 0x821A +#define GL_DEPTH24_STENCIL8 0x88F0 +#define GL_FRAMEBUFFER 0x8D40 +#define GL_INVALID_FRAMEBUFFER_OPERATION 0x0506 +#define GL_UNSIGNED_INT_24_8 0x84FA +#endif +#ifndef GL_ARB_imaging +#define GL_TABLE_TOO_LARGE 0x8031 +#endif +#ifndef GL_KHR_robustness +#define GL_CONTEXT_LOST 0x0507 +#endif + +// Declare function pointers to OpenGL extension functions. +#define GLUTIL_EXT(return_type, name, ...) extern return_type (GLAPIENTRY* name)(__VA_ARGS__); +#include "glutil_extlist.h" +#undef GLUTIL_EXT + +//------------------------------------------------------------------------ +// Common functions. +//------------------------------------------------------------------------ + +void setGLContext (GLContext& glctx); +void releaseGLContext (void); +GLContext createGLContext (int cudaDeviceIdx); +void destroyGLContext (GLContext& glctx); +const char* getGLErrorString (GLenum err); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/glutil_extlist.h b/extensions/nvdiffrast/nvdiffrast/common/glutil_extlist.h new file mode 100644 index 0000000000000000000000000000000000000000..6b0fdc2c7820ee4ead0c3e68ad8bafce033c482f --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/glutil_extlist.h @@ -0,0 +1,48 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#ifndef GL_VERSION_1_2 +GLUTIL_EXT(void, glTexImage3D, GLenum target, GLint level, GLint internalFormat, GLsizei width, GLsizei height, GLsizei depth, GLint border, GLenum format, GLenum type, const void *pixels); +#endif +#ifndef GL_VERSION_1_5 +GLUTIL_EXT(void, glBindBuffer, GLenum target, GLuint buffer); +GLUTIL_EXT(void, glBufferData, GLenum target, ptrdiff_t size, const void* data, GLenum usage); +GLUTIL_EXT(void, glGenBuffers, GLsizei n, GLuint* buffers); +#endif +#ifndef GL_VERSION_2_0 +GLUTIL_EXT(void, glAttachShader, GLuint program, GLuint shader); +GLUTIL_EXT(void, glCompileShader, GLuint shader); +GLUTIL_EXT(GLuint, glCreateProgram, void); +GLUTIL_EXT(GLuint, glCreateShader, GLenum type); +GLUTIL_EXT(void, glDrawBuffers, GLsizei n, const GLenum* bufs); +GLUTIL_EXT(void, glEnableVertexAttribArray, GLuint index); +GLUTIL_EXT(void, glGetProgramInfoLog, GLuint program, GLsizei bufSize, GLsizei* length, char* infoLog); +GLUTIL_EXT(void, glGetProgramiv, GLuint program, GLenum pname, GLint* param); +GLUTIL_EXT(void, glLinkProgram, GLuint program); +GLUTIL_EXT(void, glShaderSource, GLuint shader, GLsizei count, const char *const* string, const GLint* length); +GLUTIL_EXT(void, glUniform1f, GLint location, GLfloat v0); +GLUTIL_EXT(void, glUniform2f, GLint location, GLfloat v0, GLfloat v1); +GLUTIL_EXT(void, glUseProgram, GLuint program); +GLUTIL_EXT(void, glVertexAttribPointer, GLuint index, GLint size, GLenum type, GLboolean normalized, GLsizei stride, const void* pointer); +#endif +#ifndef GL_VERSION_3_2 +GLUTIL_EXT(void, glFramebufferTexture, GLenum target, GLenum attachment, GLuint texture, GLint level); +#endif +#ifndef GL_ARB_framebuffer_object +GLUTIL_EXT(void, glBindFramebuffer, GLenum target, GLuint framebuffer); +GLUTIL_EXT(void, glGenFramebuffers, GLsizei n, GLuint* framebuffers); +#endif +#ifndef GL_ARB_vertex_array_object +GLUTIL_EXT(void, glBindVertexArray, GLuint array); +GLUTIL_EXT(void, glGenVertexArrays, GLsizei n, GLuint* arrays); +#endif +#ifndef GL_ARB_multi_draw_indirect +GLUTIL_EXT(void, glMultiDrawElementsIndirect, GLenum mode, GLenum type, const void *indirect, GLsizei primcount, GLsizei stride); +#endif + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu b/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu new file mode 100644 index 0000000000000000000000000000000000000000..f75f5f1f4ffa39bee40c01dc6b62843e658b83c6 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/interpolate.cu @@ -0,0 +1,276 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "common.h" +#include "interpolate.h" + +//------------------------------------------------------------------------ +// Forward kernel. + +template +static __forceinline__ __device__ void InterpolateFwdKernelTemplate(const InterpolateKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.depth) + return; + + // Pixel index. + int pidx = px + p.width * (py + p.height * pz); + + // Output ptrs. + float* out = p.out + pidx * p.numAttr; + float2* outDA = ENABLE_DA ? (((float2*)p.outDA) + pidx * p.numDiffAttr) : 0; + + // Fetch rasterizer output. + float4 r = ((float4*)p.rast)[pidx]; + int triIdx = float_to_triidx(r.w) - 1; + bool triValid = (triIdx >= 0 && triIdx < p.numTriangles); + + // If no geometry in entire warp, zero the output and exit. + // Otherwise force barys to zero and output with live threads. + if (__all_sync(0xffffffffu, !triValid)) + { + for (int i=0; i < p.numAttr; i++) + out[i] = 0.f; + if (ENABLE_DA) + for (int i=0; i < p.numDiffAttr; i++) + outDA[i] = make_float2(0.f, 0.f); + return; + } + + // Fetch vertex indices. + int vi0 = triValid ? p.tri[triIdx * 3 + 0] : 0; + int vi1 = triValid ? p.tri[triIdx * 3 + 1] : 0; + int vi2 = triValid ? p.tri[triIdx * 3 + 2] : 0; + + // Bail out if corrupt indices. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index unless broadcasting. + if (p.instance_mode && !p.attrBC) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Pointers to attributes. + const float* a0 = p.attr + vi0 * p.numAttr; + const float* a1 = p.attr + vi1 * p.numAttr; + const float* a2 = p.attr + vi2 * p.numAttr; + + // Barys. If no triangle, force all to zero -> output is zero. + float b0 = triValid ? r.x : 0.f; + float b1 = triValid ? r.y : 0.f; + float b2 = triValid ? (1.f - r.x - r.y) : 0.f; + + // Interpolate and write attributes. + for (int i=0; i < p.numAttr; i++) + out[i] = b0*a0[i] + b1*a1[i] + b2*a2[i]; + + // No diff attrs? Exit. + if (!ENABLE_DA) + return; + + // Read bary pixel differentials if we have a triangle. + float4 db = make_float4(0.f, 0.f, 0.f, 0.f); + if (triValid) + db = ((float4*)p.rastDB)[pidx]; + + // Unpack a bit. + float dudx = db.x; + float dudy = db.y; + float dvdx = db.z; + float dvdy = db.w; + + // Calculate the pixel differentials of chosen attributes. + for (int i=0; i < p.numDiffAttr; i++) + { + // Input attribute index. + int j = p.diff_attrs_all ? i : p.diffAttrs[i]; + if (j < 0) + j += p.numAttr; // Python-style negative indices. + + // Zero output if invalid index. + float dsdx = 0.f; + float dsdy = 0.f; + if (j >= 0 && j < p.numAttr) + { + float s0 = a0[j]; + float s1 = a1[j]; + float s2 = a2[j]; + float dsdu = s0 - s2; + float dsdv = s1 - s2; + dsdx = dudx*dsdu + dvdx*dsdv; + dsdy = dudy*dsdu + dvdy*dsdv; + } + + // Write. + outDA[i] = make_float2(dsdx, dsdy); + } +} + +// Template specializations. +__global__ void InterpolateFwdKernel (const InterpolateKernelParams p) { InterpolateFwdKernelTemplate(p); } +__global__ void InterpolateFwdKernelDa(const InterpolateKernelParams p) { InterpolateFwdKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Gradient kernel. + +template +static __forceinline__ __device__ void InterpolateGradKernelTemplate(const InterpolateKernelParams p) +{ + // Temporary space for coalesced atomics. + CA_DECLARE_TEMP(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH * IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.depth) + return; + + // Pixel index. + int pidx = px + p.width * (py + p.height * pz); + + // Fetch triangle ID. If none, output zero bary/db gradients and exit. + float4 r = ((float4*)p.rast)[pidx]; + int triIdx = float_to_triidx(r.w) - 1; + if (triIdx < 0 || triIdx >= p.numTriangles) + { + ((float4*)p.gradRaster)[pidx] = make_float4(0.f, 0.f, 0.f, 0.f); + if (ENABLE_DA) + ((float4*)p.gradRasterDB)[pidx] = make_float4(0.f, 0.f, 0.f, 0.f); + return; + } + + // Fetch vertex indices. + int vi0 = p.tri[triIdx * 3 + 0]; + int vi1 = p.tri[triIdx * 3 + 1]; + int vi2 = p.tri[triIdx * 3 + 2]; + + // Bail out if corrupt indices. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index unless broadcasting. + if (p.instance_mode && !p.attrBC) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Initialize coalesced atomics. + CA_SET_GROUP(triIdx); + + // Pointers to inputs. + const float* a0 = p.attr + vi0 * p.numAttr; + const float* a1 = p.attr + vi1 * p.numAttr; + const float* a2 = p.attr + vi2 * p.numAttr; + const float* pdy = p.dy + pidx * p.numAttr; + + // Pointers to outputs. + float* ga0 = p.gradAttr + vi0 * p.numAttr; + float* ga1 = p.gradAttr + vi1 * p.numAttr; + float* ga2 = p.gradAttr + vi2 * p.numAttr; + + // Barys and bary gradient accumulators. + float b0 = r.x; + float b1 = r.y; + float b2 = 1.f - r.x - r.y; + float gb0 = 0.f; + float gb1 = 0.f; + + // Loop over attributes and accumulate attribute gradients. + for (int i=0; i < p.numAttr; i++) + { + float y = pdy[i]; + float s0 = a0[i]; + float s1 = a1[i]; + float s2 = a2[i]; + gb0 += y * (s0 - s2); + gb1 += y * (s1 - s2); + caAtomicAdd(ga0 + i, b0 * y); + caAtomicAdd(ga1 + i, b1 * y); + caAtomicAdd(ga2 + i, b2 * y); + } + + // Write the bary gradients. + ((float4*)p.gradRaster)[pidx] = make_float4(gb0, gb1, 0.f, 0.f); + + // If pixel differentials disabled, we're done. + if (!ENABLE_DA) + return; + + // Calculate gradients based on attribute pixel differentials. + const float2* dda = ((float2*)p.dda) + pidx * p.numDiffAttr; + float gdudx = 0.f; + float gdudy = 0.f; + float gdvdx = 0.f; + float gdvdy = 0.f; + + // Read bary pixel differentials. + float4 db = ((float4*)p.rastDB)[pidx]; + float dudx = db.x; + float dudy = db.y; + float dvdx = db.z; + float dvdy = db.w; + + for (int i=0; i < p.numDiffAttr; i++) + { + // Input attribute index. + int j = p.diff_attrs_all ? i : p.diffAttrs[i]; + if (j < 0) + j += p.numAttr; // Python-style negative indices. + + // Check that index is valid. + if (j >= 0 && j < p.numAttr) + { + float2 dsdxy = dda[i]; + float dsdx = dsdxy.x; + float dsdy = dsdxy.y; + + float s0 = a0[j]; + float s1 = a1[j]; + float s2 = a2[j]; + + // Gradients of db. + float dsdu = s0 - s2; + float dsdv = s1 - s2; + gdudx += dsdu * dsdx; + gdudy += dsdu * dsdy; + gdvdx += dsdv * dsdx; + gdvdy += dsdv * dsdy; + + // Gradients of attributes. + float du = dsdx*dudx + dsdy*dudy; + float dv = dsdx*dvdx + dsdy*dvdy; + caAtomicAdd(ga0 + j, du); + caAtomicAdd(ga1 + j, dv); + caAtomicAdd(ga2 + j, -du - dv); + } + } + + // Write. + ((float4*)p.gradRasterDB)[pidx] = make_float4(gdudx, gdudy, gdvdx, gdvdy); +} + +// Template specializations. +__global__ void InterpolateGradKernel (const InterpolateKernelParams p) { InterpolateGradKernelTemplate(p); } +__global__ void InterpolateGradKernelDa(const InterpolateKernelParams p) { InterpolateGradKernelTemplate(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/interpolate.h b/extensions/nvdiffrast/nvdiffrast/common/interpolate.h new file mode 100644 index 0000000000000000000000000000000000000000..335bb5a8baf53aa96f0cf2f51d0a0fc91ba49527 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/interpolate.h @@ -0,0 +1,49 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Constants and helpers. + +#define IP_FWD_MAX_KERNEL_BLOCK_WIDTH 8 +#define IP_FWD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define IP_GRAD_MAX_KERNEL_BLOCK_WIDTH 8 +#define IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define IP_MAX_DIFF_ATTRS 32 + +//------------------------------------------------------------------------ +// CUDA kernel params. + +struct InterpolateKernelParams +{ + const int* tri; // Incoming triangle buffer. + const float* attr; // Incoming attribute buffer. + const float* rast; // Incoming rasterizer output buffer. + const float* rastDB; // Incoming rasterizer output buffer for bary derivatives. + const float* dy; // Incoming attribute gradients. + const float* dda; // Incoming attr diff gradients. + float* out; // Outgoing interpolated attributes. + float* outDA; // Outgoing texcoord major axis lengths. + float* gradAttr; // Outgoing attribute gradients. + float* gradRaster; // Outgoing rasterizer gradients. + float* gradRasterDB; // Outgoing rasterizer bary diff gradients. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int numAttr; // Number of total vertex attributes. + int numDiffAttr; // Number of attributes to differentiate. + int width; // Image width. + int height; // Image height. + int depth; // Minibatch size. + int attrBC; // 0=normal, 1=attr is broadcast. + int instance_mode; // 0=normal, 1=instance mode. + int diff_attrs_all; // 0=normal, 1=produce pixel differentials for all attributes. + int diffAttrs[IP_MAX_DIFF_ATTRS]; // List of attributes to differentiate. +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize.cu b/extensions/nvdiffrast/nvdiffrast/common/rasterize.cu new file mode 100644 index 0000000000000000000000000000000000000000..bdc518c9bec3ef45b74dcd75bc8429107b3c9896 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize.cu @@ -0,0 +1,276 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "common.h" +#include "rasterize.h" + +//------------------------------------------------------------------------ +// Cuda forward rasterizer pixel shader kernel. + +__global__ void RasterizeCudaFwdShaderKernel(const RasterizeCudaFwdShaderParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width_out || py >= p.height_out || pz >= p.depth) + return; + + // Pixel indices. + int pidx_in = px + p.width_in * (py + p.height_in * pz); + int pidx_out = px + p.width_out * (py + p.height_out * pz); + + // Fetch triangle idx. + int triIdx = p.in_idx[pidx_in] - 1; + if (triIdx < 0 || triIdx >= p.numTriangles) + { + // No or corrupt triangle. + ((float4*)p.out)[pidx_out] = make_float4(0.0, 0.0, 0.0, 0.0); // Clear out. + ((float4*)p.out_db)[pidx_out] = make_float4(0.0, 0.0, 0.0, 0.0); // Clear out_db. + return; + } + + // Fetch vertex indices. + int vi0 = p.tri[triIdx * 3 + 0]; + int vi1 = p.tri[triIdx * 3 + 1]; + int vi2 = p.tri[triIdx * 3 + 2]; + + // Bail out if vertex indices are corrupt. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index. + if (p.instance_mode) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Fetch vertex positions. + float4 p0 = ((float4*)p.pos)[vi0]; + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + + // Evaluate edge functions. + float fx = p.xs * (float)px + p.xo; + float fy = p.ys * (float)py + p.yo; + float p0x = p0.x - fx * p0.w; + float p0y = p0.y - fy * p0.w; + float p1x = p1.x - fx * p1.w; + float p1y = p1.y - fy * p1.w; + float p2x = p2.x - fx * p2.w; + float p2y = p2.y - fy * p2.w; + float a0 = p1x*p2y - p1y*p2x; + float a1 = p2x*p0y - p2y*p0x; + float a2 = p0x*p1y - p0y*p1x; + + // Perspective correct, normalized barycentrics. + float iw = 1.f / (a0 + a1 + a2); + float b0 = a0 * iw; + float b1 = a1 * iw; + + // Compute z/w for depth buffer. + float z = p0.z * a0 + p1.z * a1 + p2.z * a2; + float w = p0.w * a0 + p1.w * a1 + p2.w * a2; + float zw = z / w; + + // Clamps to avoid NaNs. + b0 = __saturatef(b0); // Clamp to [+0.0, 1.0]. + b1 = __saturatef(b1); // Clamp to [+0.0, 1.0]. + zw = fmaxf(fminf(zw, 1.f), -1.f); + + // Emit output. + ((float4*)p.out)[pidx_out] = make_float4(b0, b1, zw, triidx_to_float(triIdx + 1)); + + // Calculate bary pixel differentials. + float dfxdx = p.xs * iw; + float dfydy = p.ys * iw; + float da0dx = p2.y*p1.w - p1.y*p2.w; + float da0dy = p1.x*p2.w - p2.x*p1.w; + float da1dx = p0.y*p2.w - p2.y*p0.w; + float da1dy = p2.x*p0.w - p0.x*p2.w; + float da2dx = p1.y*p0.w - p0.y*p1.w; + float da2dy = p0.x*p1.w - p1.x*p0.w; + float datdx = da0dx + da1dx + da2dx; + float datdy = da0dy + da1dy + da2dy; + float dudx = dfxdx * (b0 * datdx - da0dx); + float dudy = dfydy * (b0 * datdy - da0dy); + float dvdx = dfxdx * (b1 * datdx - da1dx); + float dvdy = dfydy * (b1 * datdy - da1dy); + + // Emit bary pixel differentials. + ((float4*)p.out_db)[pidx_out] = make_float4(dudx, dudy, dvdx, dvdy); +} + +//------------------------------------------------------------------------ +// Gradient Cuda kernel. + +template +static __forceinline__ __device__ void RasterizeGradKernelTemplate(const RasterizeGradParams p) +{ + // Temporary space for coalesced atomics. + CA_DECLARE_TEMP(RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH * RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.width || py >= p.height || pz >= p.depth) + return; + + // Pixel index. + int pidx = px + p.width * (py + p.height * pz); + + // Read triangle idx and dy. + float2 dy = ((float2*)p.dy)[pidx * 2]; + float4 ddb = ENABLE_DB ? ((float4*)p.ddb)[pidx] : make_float4(0.f, 0.f, 0.f, 0.f); + int triIdx = float_to_triidx(((float*)p.out)[pidx * 4 + 3]) - 1; + + // Exit if nothing to do. + if (triIdx < 0 || triIdx >= p.numTriangles) + return; // No or corrupt triangle. + int grad_all_dy = __float_as_int(dy.x) | __float_as_int(dy.y); // Bitwise OR of all incoming gradients. + int grad_all_ddb = 0; + if (ENABLE_DB) + grad_all_ddb = __float_as_int(ddb.x) | __float_as_int(ddb.y) | __float_as_int(ddb.z) | __float_as_int(ddb.w); + if (((grad_all_dy | grad_all_ddb) << 1) == 0) + return; // All incoming gradients are +0/-0. + + // Fetch vertex indices. + int vi0 = p.tri[triIdx * 3 + 0]; + int vi1 = p.tri[triIdx * 3 + 1]; + int vi2 = p.tri[triIdx * 3 + 2]; + + // Bail out if vertex indices are corrupt. + if (vi0 < 0 || vi0 >= p.numVertices || + vi1 < 0 || vi1 >= p.numVertices || + vi2 < 0 || vi2 >= p.numVertices) + return; + + // In instance mode, adjust vertex indices by minibatch index. + if (p.instance_mode) + { + vi0 += pz * p.numVertices; + vi1 += pz * p.numVertices; + vi2 += pz * p.numVertices; + } + + // Initialize coalesced atomics. + CA_SET_GROUP(triIdx); + + // Fetch vertex positions. + float4 p0 = ((float4*)p.pos)[vi0]; + float4 p1 = ((float4*)p.pos)[vi1]; + float4 p2 = ((float4*)p.pos)[vi2]; + + // Evaluate edge functions. + float fx = p.xs * (float)px + p.xo; + float fy = p.ys * (float)py + p.yo; + float p0x = p0.x - fx * p0.w; + float p0y = p0.y - fy * p0.w; + float p1x = p1.x - fx * p1.w; + float p1y = p1.y - fy * p1.w; + float p2x = p2.x - fx * p2.w; + float p2y = p2.y - fy * p2.w; + float a0 = p1x*p2y - p1y*p2x; + float a1 = p2x*p0y - p2y*p0x; + float a2 = p0x*p1y - p0y*p1x; + + // Compute inverse area with epsilon. + float at = a0 + a1 + a2; + float ep = copysignf(1e-6f, at); // ~1 pixel in 1k x 1k image. + float iw = 1.f / (at + ep); + + // Perspective correct, normalized barycentrics. + float b0 = a0 * iw; + float b1 = a1 * iw; + + // Position gradients. + float gb0 = dy.x * iw; + float gb1 = dy.y * iw; + float gbb = gb0 * b0 + gb1 * b1; + float gp0x = gbb * (p2y - p1y) - gb1 * p2y; + float gp1x = gbb * (p0y - p2y) + gb0 * p2y; + float gp2x = gbb * (p1y - p0y) - gb0 * p1y + gb1 * p0y; + float gp0y = gbb * (p1x - p2x) + gb1 * p2x; + float gp1y = gbb * (p2x - p0x) - gb0 * p2x; + float gp2y = gbb * (p0x - p1x) + gb0 * p1x - gb1 * p0x; + float gp0w = -fx * gp0x - fy * gp0y; + float gp1w = -fx * gp1x - fy * gp1y; + float gp2w = -fx * gp2x - fy * gp2y; + + // Bary differential gradients. + if (ENABLE_DB && ((grad_all_ddb) << 1) != 0) + { + float dfxdX = p.xs * iw; + float dfydY = p.ys * iw; + ddb.x *= dfxdX; + ddb.y *= dfydY; + ddb.z *= dfxdX; + ddb.w *= dfydY; + + float da0dX = p1.y * p2.w - p2.y * p1.w; + float da1dX = p2.y * p0.w - p0.y * p2.w; + float da2dX = p0.y * p1.w - p1.y * p0.w; + float da0dY = p2.x * p1.w - p1.x * p2.w; + float da1dY = p0.x * p2.w - p2.x * p0.w; + float da2dY = p1.x * p0.w - p0.x * p1.w; + float datdX = da0dX + da1dX + da2dX; + float datdY = da0dY + da1dY + da2dY; + + float x01 = p0.x - p1.x; + float x12 = p1.x - p2.x; + float x20 = p2.x - p0.x; + float y01 = p0.y - p1.y; + float y12 = p1.y - p2.y; + float y20 = p2.y - p0.y; + float w01 = p0.w - p1.w; + float w12 = p1.w - p2.w; + float w20 = p2.w - p0.w; + + float a0p1 = fy * p2.x - fx * p2.y; + float a0p2 = fx * p1.y - fy * p1.x; + float a1p0 = fx * p2.y - fy * p2.x; + float a1p2 = fy * p0.x - fx * p0.y; + + float wdudX = 2.f * b0 * datdX - da0dX; + float wdudY = 2.f * b0 * datdY - da0dY; + float wdvdX = 2.f * b1 * datdX - da1dX; + float wdvdY = 2.f * b1 * datdY - da1dY; + + float c0 = iw * (ddb.x * wdudX + ddb.y * wdudY + ddb.z * wdvdX + ddb.w * wdvdY); + float cx = c0 * fx - ddb.x * b0 - ddb.z * b1; + float cy = c0 * fy - ddb.y * b0 - ddb.w * b1; + float cxy = iw * (ddb.x * datdX + ddb.y * datdY); + float czw = iw * (ddb.z * datdX + ddb.w * datdY); + + gp0x += c0 * y12 - cy * w12 + czw * p2y + ddb.w * p2.w; + gp1x += c0 * y20 - cy * w20 - cxy * p2y - ddb.y * p2.w; + gp2x += c0 * y01 - cy * w01 + cxy * p1y - czw * p0y + ddb.y * p1.w - ddb.w * p0.w; + gp0y += cx * w12 - c0 * x12 - czw * p2x - ddb.z * p2.w; + gp1y += cx * w20 - c0 * x20 + cxy * p2x + ddb.x * p2.w; + gp2y += cx * w01 - c0 * x01 - cxy * p1x + czw * p0x - ddb.x * p1.w + ddb.z * p0.w; + gp0w += cy * x12 - cx * y12 - czw * a1p0 + ddb.z * p2.y - ddb.w * p2.x; + gp1w += cy * x20 - cx * y20 - cxy * a0p1 - ddb.x * p2.y + ddb.y * p2.x; + gp2w += cy * x01 - cx * y01 - cxy * a0p2 - czw * a1p2 + ddb.x * p1.y - ddb.y * p1.x - ddb.z * p0.y + ddb.w * p0.x; + } + + // Accumulate using coalesced atomics. + caAtomicAdd3_xyw(p.grad + 4 * vi0, gp0x, gp0y, gp0w); + caAtomicAdd3_xyw(p.grad + 4 * vi1, gp1x, gp1y, gp1w); + caAtomicAdd3_xyw(p.grad + 4 * vi2, gp2x, gp2y, gp2w); +} + +// Template specializations. +__global__ void RasterizeGradKernel (const RasterizeGradParams p) { RasterizeGradKernelTemplate(p); } +__global__ void RasterizeGradKernelDb(const RasterizeGradParams p) { RasterizeGradKernelTemplate(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize.h b/extensions/nvdiffrast/nvdiffrast/common/rasterize.h new file mode 100644 index 0000000000000000000000000000000000000000..30c5e4951239891c4d9c1f28bc0012a39c4a6b06 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize.h @@ -0,0 +1,60 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Constants and helpers. + +#define RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_WIDTH 8 +#define RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_HEIGHT 8 +#define RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH 8 +#define RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT 8 + +//------------------------------------------------------------------------ +// CUDA forward rasterizer shader kernel params. + +struct RasterizeCudaFwdShaderParams +{ + const float* pos; // Vertex positions. + const int* tri; // Triangle indices. + const int* in_idx; // Triangle idx buffer from rasterizer. + float* out; // Main output buffer. + float* out_db; // Bary pixel gradient output buffer. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int width_in; // Input image width. + int height_in; // Input image height. + int width_out; // Output image width. + int height_out; // Output image height. + int depth; // Size of minibatch. + int instance_mode; // 1 if in instance rendering mode. + float xs, xo, ys, yo; // Pixel position to clip-space x, y transform. +}; + +//------------------------------------------------------------------------ +// Gradient CUDA kernel params. + +struct RasterizeGradParams +{ + const float* pos; // Incoming position buffer. + const int* tri; // Incoming triangle buffer. + const float* out; // Rasterizer output buffer. + const float* dy; // Incoming gradients of rasterizer output buffer. + const float* ddb; // Incoming gradients of bary diff output buffer. + float* grad; // Outgoing position gradients. + int numTriangles; // Number of triangles. + int numVertices; // Number of vertices. + int width; // Image width. + int height; // Image height. + int depth; // Size of minibatch. + int instance_mode; // 1 if in instance rendering mode. + float xs, xo, ys, yo; // Pixel position to clip-space x, y transform. +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.cpp b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..404dcf221bb1182a6755e503927127be113453a8 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.cpp @@ -0,0 +1,644 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "rasterize_gl.h" +#include "glutil.h" +#include +#define STRINGIFY_SHADER_SOURCE(x) #x + +//------------------------------------------------------------------------ +// Helpers. + +#define ROUND_UP(x, y) ((((x) + ((y) - 1)) / (y)) * (y)) +static int ROUND_UP_BITS(uint32_t x, uint32_t y) +{ + // Round x up so that it has at most y bits of mantissa. + if (x < (1u << y)) + return x; + uint32_t m = 0; + while (x & ~m) + m = (m << 1) | 1u; + m >>= y; + if (!(x & m)) + return x; + return (x | m) + 1u; +} + +//------------------------------------------------------------------------ +// Draw command struct used by rasterizer. + +struct GLDrawCmd +{ + uint32_t count; + uint32_t instanceCount; + uint32_t firstIndex; + uint32_t baseVertex; + uint32_t baseInstance; +}; + +//------------------------------------------------------------------------ +// GL helpers. + +static void compileGLShader(NVDR_CTX_ARGS, const RasterizeGLState& s, GLuint* pShader, GLenum shaderType, const char* src_buf) +{ + std::string src(src_buf); + + // Set preprocessor directives. + int n = src.find('\n') + 1; // After first line containing #version directive. + if (s.enableZModify) + src.insert(n, "#define IF_ZMODIFY(x) x\n"); + else + src.insert(n, "#define IF_ZMODIFY(x)\n"); + + const char *cstr = src.c_str(); + *pShader = 0; + NVDR_CHECK_GL_ERROR(*pShader = glCreateShader(shaderType)); + NVDR_CHECK_GL_ERROR(glShaderSource(*pShader, 1, &cstr, 0)); + NVDR_CHECK_GL_ERROR(glCompileShader(*pShader)); +} + +static void constructGLProgram(NVDR_CTX_ARGS, GLuint* pProgram, GLuint glVertexShader, GLuint glGeometryShader, GLuint glFragmentShader) +{ + *pProgram = 0; + + GLuint glProgram = 0; + NVDR_CHECK_GL_ERROR(glProgram = glCreateProgram()); + NVDR_CHECK_GL_ERROR(glAttachShader(glProgram, glVertexShader)); + NVDR_CHECK_GL_ERROR(glAttachShader(glProgram, glGeometryShader)); + NVDR_CHECK_GL_ERROR(glAttachShader(glProgram, glFragmentShader)); + NVDR_CHECK_GL_ERROR(glLinkProgram(glProgram)); + + GLint linkStatus = 0; + NVDR_CHECK_GL_ERROR(glGetProgramiv(glProgram, GL_LINK_STATUS, &linkStatus)); + if (!linkStatus) + { + GLint infoLen = 0; + NVDR_CHECK_GL_ERROR(glGetProgramiv(glProgram, GL_INFO_LOG_LENGTH, &infoLen)); + if (infoLen) + { + const char* hdr = "glLinkProgram() failed:\n"; + std::vector info(strlen(hdr) + infoLen); + strcpy(&info[0], hdr); + NVDR_CHECK_GL_ERROR(glGetProgramInfoLog(glProgram, infoLen, &infoLen, &info[strlen(hdr)])); + NVDR_CHECK(0, &info[0]); + } + NVDR_CHECK(0, "glLinkProgram() failed"); + } + + *pProgram = glProgram; +} + +//------------------------------------------------------------------------ +// Shared C++ functions. + +void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceIdx) +{ + // Create GL context and set it current. + s.glctx = createGLContext(cudaDeviceIdx); + setGLContext(s.glctx); + + // Version check. + GLint vMajor = 0; + GLint vMinor = 0; + glGetIntegerv(GL_MAJOR_VERSION, &vMajor); + glGetIntegerv(GL_MINOR_VERSION, &vMinor); + glGetError(); // Clear possible GL_INVALID_ENUM error in version query. + LOG(INFO) << "OpenGL version reported as " << vMajor << "." << vMinor; + NVDR_CHECK((vMajor == 4 && vMinor >= 4) || vMajor > 4, "OpenGL 4.4 or later is required"); + + // Enable depth modification workaround on A100 and later. + int capMajor = 0; + NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&capMajor, cudaDevAttrComputeCapabilityMajor, cudaDeviceIdx)); + s.enableZModify = (capMajor >= 8); + + // Number of output buffers. + int num_outputs = s.enableDB ? 2 : 1; + + // Set up vertex shader. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glVertexShader, GL_VERTEX_SHADER, + "#version 330\n" + "#extension GL_ARB_shader_draw_parameters : enable\n" + STRINGIFY_SHADER_SOURCE( + layout(location = 0) in vec4 in_pos; + out int v_layer; + out int v_offset; + void main() + { + int layer = gl_DrawIDARB; + gl_Position = in_pos; + v_layer = layer; + v_offset = gl_BaseInstanceARB; // Sneak in TriID offset here. + } + ) + ); + + // Geometry and fragment shaders depend on if bary differential output is enabled or not. + if (s.enableDB) + { + // Set up geometry shader. Calculation of per-pixel bary differentials is based on: + // u = (u/w) / (1/w) + // --> du/dX = d((u/w) / (1/w))/dX + // --> du/dX = [d(u/w)/dX - u*d(1/w)/dX] * w + // and we know both d(u/w)/dX and d(1/w)/dX are constant over triangle. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glGeometryShader, GL_GEOMETRY_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + layout(triangles) in; + layout(triangle_strip, max_vertices=3) out; + layout(location = 0) uniform vec2 vp_scale; + in int v_layer[]; + in int v_offset[]; + out vec4 var_uvzw; + out vec4 var_db; + void main() + { + // Plane equations for bary differentials. + float w0 = gl_in[0].gl_Position.w; + float w1 = gl_in[1].gl_Position.w; + float w2 = gl_in[2].gl_Position.w; + vec2 p0 = gl_in[0].gl_Position.xy; + vec2 p1 = gl_in[1].gl_Position.xy; + vec2 p2 = gl_in[2].gl_Position.xy; + vec2 e0 = p0*w2 - p2*w0; + vec2 e1 = p1*w2 - p2*w1; + float a = e0.x*e1.y - e0.y*e1.x; + + // Clamp area to an epsilon to avoid arbitrarily high bary differentials. + float eps = 1e-6f; // ~1 pixel in 1k x 1k image. + float ca = (abs(a) >= eps) ? a : (a < 0.f) ? -eps : eps; // Clamp with sign. + float ia = 1.f / ca; // Inverse area. + + vec2 ascl = ia * vp_scale; + float dudx = e1.y * ascl.x; + float dudy = -e1.x * ascl.y; + float dvdx = -e0.y * ascl.x; + float dvdy = e0.x * ascl.y; + + float duwdx = w2 * dudx; + float dvwdx = w2 * dvdx; + float duvdx = w0 * dudx + w1 * dvdx; + float duwdy = w2 * dudy; + float dvwdy = w2 * dvdy; + float duvdy = w0 * dudy + w1 * dvdy; + + vec4 db0 = vec4(duvdx - dvwdx, duvdy - dvwdy, dvwdx, dvwdy); + vec4 db1 = vec4(duwdx, duwdy, duvdx - duwdx, duvdy - duwdy); + vec4 db2 = vec4(duwdx, duwdy, dvwdx, dvwdy); + + int layer_id = v_layer[0]; + int prim_id = gl_PrimitiveIDIn + v_offset[0]; + + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[0].gl_Position.x, gl_in[0].gl_Position.y, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); var_uvzw = vec4(1.f, 0.f, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); var_db = db0; EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[1].gl_Position.x, gl_in[1].gl_Position.y, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); var_uvzw = vec4(0.f, 1.f, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); var_db = db1; EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[2].gl_Position.x, gl_in[2].gl_Position.y, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); var_uvzw = vec4(0.f, 0.f, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); var_db = db2; EmitVertex(); + } + ) + ); + + // Set up fragment shader. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShader, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + in vec4 var_db; + layout(location = 0) out vec4 out_raster; + layout(location = 1) out vec4 out_db; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, id_float); + out_db = var_db * var_uvzw.w; + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + + // Set up fragment shader for depth peeling. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + in vec4 var_db; + layout(binding = 0) uniform sampler2DArray out_prev; + layout(location = 0) out vec4 out_raster; + layout(location = 1) out vec4 out_db; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + vec4 prev = texelFetch(out_prev, ivec3(gl_FragCoord.x, gl_FragCoord.y, gl_Layer), 0); + float depth_new = var_uvzw.z / var_uvzw.w; + if (prev.w == 0 || depth_new <= prev.z) + discard; + out_raster = vec4(var_uvzw.x, var_uvzw.y, depth_new, id_float); + out_db = var_db * var_uvzw.w; + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + } + else + { + // Geometry shader without bary differential output. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glGeometryShader, GL_GEOMETRY_SHADER, + "#version 330\n" + STRINGIFY_SHADER_SOURCE( + layout(triangles) in; + layout(triangle_strip, max_vertices=3) out; + in int v_layer[]; + in int v_offset[]; + out vec4 var_uvzw; + void main() + { + int layer_id = v_layer[0]; + int prim_id = gl_PrimitiveIDIn + v_offset[0]; + + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[0].gl_Position.x, gl_in[0].gl_Position.y, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); var_uvzw = vec4(1.f, 0.f, gl_in[0].gl_Position.z, gl_in[0].gl_Position.w); EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[1].gl_Position.x, gl_in[1].gl_Position.y, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); var_uvzw = vec4(0.f, 1.f, gl_in[1].gl_Position.z, gl_in[1].gl_Position.w); EmitVertex(); + gl_Layer = layer_id; gl_PrimitiveID = prim_id; gl_Position = vec4(gl_in[2].gl_Position.x, gl_in[2].gl_Position.y, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); var_uvzw = vec4(0.f, 0.f, gl_in[2].gl_Position.z, gl_in[2].gl_Position.w); EmitVertex(); + } + ) + ); + + // Fragment shader without bary differential output. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShader, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + layout(location = 0) out vec4 out_raster; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, id_float); + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + + // Depth peeling variant of fragment shader. + compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER, + "#version 430\n" + STRINGIFY_SHADER_SOURCE( + in vec4 var_uvzw; + layout(binding = 0) uniform sampler2DArray out_prev; + layout(location = 0) out vec4 out_raster; + IF_ZMODIFY( + layout(location = 1) uniform float in_dummy; + ) + void main() + { + int id_int = gl_PrimitiveID + 1; + float id_float = (id_int <= 0x01000000) ? float(id_int) : intBitsToFloat(0x4a800000 + id_int); + + vec4 prev = texelFetch(out_prev, ivec3(gl_FragCoord.x, gl_FragCoord.y, gl_Layer), 0); + float depth_new = var_uvzw.z / var_uvzw.w; + if (prev.w == 0 || depth_new <= prev.z) + discard; + out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, id_float); + IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;) + } + ) + ); + } + + // Finalize programs. + constructGLProgram(NVDR_CTX_PARAMS, &s.glProgram, s.glVertexShader, s.glGeometryShader, s.glFragmentShader); + constructGLProgram(NVDR_CTX_PARAMS, &s.glProgramDP, s.glVertexShader, s.glGeometryShader, s.glFragmentShaderDP); + + // Construct main fbo and bind permanently. + NVDR_CHECK_GL_ERROR(glGenFramebuffers(1, &s.glFBO)); + NVDR_CHECK_GL_ERROR(glBindFramebuffer(GL_FRAMEBUFFER, s.glFBO)); + + // Enable two color attachments. + GLenum draw_buffers[2] = { GL_COLOR_ATTACHMENT0, GL_COLOR_ATTACHMENT1 }; + NVDR_CHECK_GL_ERROR(glDrawBuffers(num_outputs, draw_buffers)); + + // Construct vertex array object. + NVDR_CHECK_GL_ERROR(glGenVertexArrays(1, &s.glVAO)); + NVDR_CHECK_GL_ERROR(glBindVertexArray(s.glVAO)); + + // Construct position buffer, bind permanently, enable, set ptr. + NVDR_CHECK_GL_ERROR(glGenBuffers(1, &s.glPosBuffer)); + NVDR_CHECK_GL_ERROR(glBindBuffer(GL_ARRAY_BUFFER, s.glPosBuffer)); + NVDR_CHECK_GL_ERROR(glEnableVertexAttribArray(0)); + NVDR_CHECK_GL_ERROR(glVertexAttribPointer(0, 4, GL_FLOAT, GL_FALSE, 0, 0)); + + // Construct index buffer and bind permanently. + NVDR_CHECK_GL_ERROR(glGenBuffers(1, &s.glTriBuffer)); + NVDR_CHECK_GL_ERROR(glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, s.glTriBuffer)); + + // Set up depth test. + NVDR_CHECK_GL_ERROR(glEnable(GL_DEPTH_TEST)); + NVDR_CHECK_GL_ERROR(glDepthFunc(GL_LESS)); + NVDR_CHECK_GL_ERROR(glClearDepth(1.0)); + + // Create and bind output buffers. Storage is allocated later. + NVDR_CHECK_GL_ERROR(glGenTextures(num_outputs, s.glColorBuffer)); + for (int i=0; i < num_outputs; i++) + { + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glColorBuffer[i])); + NVDR_CHECK_GL_ERROR(glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0 + i, s.glColorBuffer[i], 0)); + } + + // Create and bind depth/stencil buffer. Storage is allocated later. + NVDR_CHECK_GL_ERROR(glGenTextures(1, &s.glDepthStencilBuffer)); + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glDepthStencilBuffer)); + NVDR_CHECK_GL_ERROR(glFramebufferTexture(GL_FRAMEBUFFER, GL_DEPTH_STENCIL_ATTACHMENT, s.glDepthStencilBuffer, 0)); + + // Create texture name for previous output buffer (depth peeling). + NVDR_CHECK_GL_ERROR(glGenTextures(1, &s.glPrevOutBuffer)); +} + +void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, bool& changes, int posCount, int triCount, int width, int height, int depth) +{ + changes = false; + + // Resize vertex buffer? + if (posCount > s.posCount) + { + if (s.cudaPosBuffer) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPosBuffer)); + s.posCount = (posCount > 64) ? ROUND_UP_BITS(posCount, 2) : 64; + LOG(INFO) << "Increasing position buffer size to " << s.posCount << " float32"; + NVDR_CHECK_GL_ERROR(glBufferData(GL_ARRAY_BUFFER, s.posCount * sizeof(float), NULL, GL_DYNAMIC_DRAW)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterBuffer(&s.cudaPosBuffer, s.glPosBuffer, cudaGraphicsRegisterFlagsWriteDiscard)); + changes = true; + } + + // Resize triangle buffer? + if (triCount > s.triCount) + { + if (s.cudaTriBuffer) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaTriBuffer)); + s.triCount = (triCount > 64) ? ROUND_UP_BITS(triCount, 2) : 64; + LOG(INFO) << "Increasing triangle buffer size to " << s.triCount << " int32"; + NVDR_CHECK_GL_ERROR(glBufferData(GL_ELEMENT_ARRAY_BUFFER, s.triCount * sizeof(int32_t), NULL, GL_DYNAMIC_DRAW)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterBuffer(&s.cudaTriBuffer, s.glTriBuffer, cudaGraphicsRegisterFlagsWriteDiscard)); + changes = true; + } + + // Resize framebuffer? + if (width > s.width || height > s.height || depth > s.depth) + { + int num_outputs = s.enableDB ? 2 : 1; + if (s.cudaColorBuffer[0]) + for (int i=0; i < num_outputs; i++) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaColorBuffer[i])); + + if (s.cudaPrevOutBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPrevOutBuffer)); + s.cudaPrevOutBuffer = 0; + } + + // New framebuffer size. + s.width = (width > s.width) ? width : s.width; + s.height = (height > s.height) ? height : s.height; + s.depth = (depth > s.depth) ? depth : s.depth; + s.width = ROUND_UP(s.width, 32); + s.height = ROUND_UP(s.height, 32); + LOG(INFO) << "Increasing frame buffer size to (width, height, depth) = (" << s.width << ", " << s.height << ", " << s.depth << ")"; + + // Allocate color buffers. + for (int i=0; i < num_outputs; i++) + { + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glColorBuffer[i])); + NVDR_CHECK_GL_ERROR(glTexImage3D(GL_TEXTURE_2D_ARRAY, 0, GL_RGBA32F, s.width, s.height, s.depth, 0, GL_RGBA, GL_UNSIGNED_BYTE, 0)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); + } + + // Allocate depth/stencil buffer. + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glDepthStencilBuffer)); + NVDR_CHECK_GL_ERROR(glTexImage3D(GL_TEXTURE_2D_ARRAY, 0, GL_DEPTH24_STENCIL8, s.width, s.height, s.depth, 0, GL_DEPTH_STENCIL, GL_UNSIGNED_INT_24_8, 0)); + + // (Re-)register all GL buffers into Cuda. + for (int i=0; i < num_outputs; i++) + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterImage(&s.cudaColorBuffer[i], s.glColorBuffer[i], GL_TEXTURE_3D, cudaGraphicsRegisterFlagsReadOnly)); + + changes = true; + } +} + +void rasterizeRender(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, const float* posPtr, int posCount, int vtxPerInstance, const int32_t* triPtr, int triCount, const int32_t* rangesPtr, int width, int height, int depth, int peeling_idx) +{ + // Only copy inputs if we are on first iteration of depth peeling or not doing it at all. + if (peeling_idx < 1) + { + if (triPtr) + { + // Copy both position and triangle buffers. + void* glPosPtr = NULL; + void* glTriPtr = NULL; + size_t posBytes = 0; + size_t triBytes = 0; + NVDR_CHECK_CUDA_ERROR(cudaGraphicsMapResources(2, &s.cudaPosBuffer, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsResourceGetMappedPointer(&glPosPtr, &posBytes, s.cudaPosBuffer)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsResourceGetMappedPointer(&glTriPtr, &triBytes, s.cudaTriBuffer)); + NVDR_CHECK(posBytes >= posCount * sizeof(float), "mapped GL position buffer size mismatch"); + NVDR_CHECK(triBytes >= triCount * sizeof(int32_t), "mapped GL triangle buffer size mismatch"); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(glPosPtr, posPtr, posCount * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(glTriPtr, triPtr, triCount * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(2, &s.cudaPosBuffer, stream)); + } + else + { + // Copy position buffer only. Triangles are already copied and known to be constant. + void* glPosPtr = NULL; + size_t posBytes = 0; + NVDR_CHECK_CUDA_ERROR(cudaGraphicsMapResources(1, &s.cudaPosBuffer, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsResourceGetMappedPointer(&glPosPtr, &posBytes, s.cudaPosBuffer)); + NVDR_CHECK(posBytes >= posCount * sizeof(float), "mapped GL position buffer size mismatch"); + NVDR_CHECK_CUDA_ERROR(cudaMemcpyAsync(glPosPtr, posPtr, posCount * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(1, &s.cudaPosBuffer, stream)); + } + } + + // Select program based on whether we have a depth peeling input or not. + if (peeling_idx < 1) + { + // Normal case: No peeling, or peeling disabled. + NVDR_CHECK_GL_ERROR(glUseProgram(s.glProgram)); + } + else + { + // If we don't have a third buffer yet, create one. + if (!s.cudaPrevOutBuffer) + { + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glPrevOutBuffer)); + NVDR_CHECK_GL_ERROR(glTexImage3D(GL_TEXTURE_2D_ARRAY, 0, GL_RGBA32F, s.width, s.height, s.depth, 0, GL_RGBA, GL_UNSIGNED_BYTE, 0)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); + NVDR_CHECK_GL_ERROR(glTexParameteri(GL_TEXTURE_2D_ARRAY, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); + NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterImage(&s.cudaPrevOutBuffer, s.glPrevOutBuffer, GL_TEXTURE_3D, cudaGraphicsRegisterFlagsReadOnly)); + } + + // Swap the GL buffers. + GLuint glTempBuffer = s.glPrevOutBuffer; + s.glPrevOutBuffer = s.glColorBuffer[0]; + s.glColorBuffer[0] = glTempBuffer; + + // Swap the Cuda buffers. + cudaGraphicsResource_t cudaTempBuffer = s.cudaPrevOutBuffer; + s.cudaPrevOutBuffer = s.cudaColorBuffer[0]; + s.cudaColorBuffer[0] = cudaTempBuffer; + + // Bind the new output buffer. + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glColorBuffer[0])); + NVDR_CHECK_GL_ERROR(glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, s.glColorBuffer[0], 0)); + + // Bind old buffer as the input texture. + NVDR_CHECK_GL_ERROR(glBindTexture(GL_TEXTURE_2D_ARRAY, s.glPrevOutBuffer)); + + // Activate the correct program. + NVDR_CHECK_GL_ERROR(glUseProgram(s.glProgramDP)); + } + + // Set viewport, clear color buffer(s) and depth/stencil buffer. + NVDR_CHECK_GL_ERROR(glViewport(0, 0, width, height)); + NVDR_CHECK_GL_ERROR(glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT | GL_STENCIL_BUFFER_BIT)); + + // If outputting bary differentials, set resolution uniform + if (s.enableDB) + NVDR_CHECK_GL_ERROR(glUniform2f(0, 2.f / (float)width, 2.f / (float)height)); + + // Set the dummy uniform if depth modification workaround is active. + if (s.enableZModify) + NVDR_CHECK_GL_ERROR(glUniform1f(1, 0.f)); + + // Render the meshes. + if (depth == 1 && !rangesPtr) + { + // Trivial case. + NVDR_CHECK_GL_ERROR(glDrawElements(GL_TRIANGLES, triCount, GL_UNSIGNED_INT, 0)); + } + else + { + // Populate a buffer for draw commands and execute it. + std::vector drawCmdBuffer(depth); + + if (!rangesPtr) + { + // Fill in range array to instantiate the same triangles for each output layer. + // Triangle IDs starts at zero (i.e., one) for each layer, so they correspond to + // the first dimension in addressing the triangle array. + for (int i=0; i < depth; i++) + { + GLDrawCmd& cmd = drawCmdBuffer[i]; + cmd.firstIndex = 0; + cmd.count = triCount; + cmd.baseVertex = vtxPerInstance * i; + cmd.baseInstance = 0; + cmd.instanceCount = 1; + } + } + else + { + // Fill in the range array according to user-given ranges. Triangle IDs point + // to the input triangle array, NOT index within range, so they correspond to + // the first dimension in addressing the triangle array. + for (int i=0, j=0; i < depth; i++) + { + GLDrawCmd& cmd = drawCmdBuffer[i]; + int first = rangesPtr[j++]; + int count = rangesPtr[j++]; + NVDR_CHECK(first >= 0 && count >= 0, "range contains negative values"); + NVDR_CHECK((first + count) * 3 <= triCount, "range extends beyond end of triangle buffer"); + cmd.firstIndex = first * 3; + cmd.count = count * 3; + cmd.baseVertex = 0; + cmd.baseInstance = first; + cmd.instanceCount = 1; + } + } + + // Draw! + NVDR_CHECK_GL_ERROR(glMultiDrawElementsIndirect(GL_TRIANGLES, GL_UNSIGNED_INT, &drawCmdBuffer[0], depth, sizeof(GLDrawCmd))); + } +} + +void rasterizeCopyResults(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, float** outputPtr, int width, int height, int depth) +{ + // Copy color buffers to output tensors. + cudaArray_t array = 0; + cudaChannelFormatDesc arrayDesc = {}; // For error checking. + cudaExtent arrayExt = {}; // For error checking. + int num_outputs = s.enableDB ? 2 : 1; + NVDR_CHECK_CUDA_ERROR(cudaGraphicsMapResources(num_outputs, s.cudaColorBuffer, stream)); + for (int i=0; i < num_outputs; i++) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsSubResourceGetMappedArray(&array, s.cudaColorBuffer[i], 0, 0)); + NVDR_CHECK_CUDA_ERROR(cudaArrayGetInfo(&arrayDesc, &arrayExt, NULL, array)); + NVDR_CHECK(arrayDesc.f == cudaChannelFormatKindFloat, "CUDA mapped array data kind mismatch"); + NVDR_CHECK(arrayDesc.x == 32 && arrayDesc.y == 32 && arrayDesc.z == 32 && arrayDesc.w == 32, "CUDA mapped array data width mismatch"); + NVDR_CHECK(arrayExt.width >= width && arrayExt.height >= height && arrayExt.depth >= depth, "CUDA mapped array extent mismatch"); + cudaMemcpy3DParms p = {0}; + p.srcArray = array; + p.dstPtr.ptr = outputPtr[i]; + p.dstPtr.pitch = width * 4 * sizeof(float); + p.dstPtr.xsize = width; + p.dstPtr.ysize = height; + p.extent.width = width; + p.extent.height = height; + p.extent.depth = depth; + p.kind = cudaMemcpyDeviceToDevice; + NVDR_CHECK_CUDA_ERROR(cudaMemcpy3DAsync(&p, stream)); + } + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(num_outputs, s.cudaColorBuffer, stream)); +} + +void rasterizeReleaseBuffers(NVDR_CTX_ARGS, RasterizeGLState& s) +{ + int num_outputs = s.enableDB ? 2 : 1; + + if (s.cudaPosBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPosBuffer)); + s.cudaPosBuffer = 0; + } + + if (s.cudaTriBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaTriBuffer)); + s.cudaTriBuffer = 0; + } + + for (int i=0; i < num_outputs; i++) + { + if (s.cudaColorBuffer[i]) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaColorBuffer[i])); + s.cudaColorBuffer[i] = 0; + } + } + + if (s.cudaPrevOutBuffer) + { + NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnregisterResource(s.cudaPrevOutBuffer)); + s.cudaPrevOutBuffer = 0; + } +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.h b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.h new file mode 100644 index 0000000000000000000000000000000000000000..ed0f75036a837cf5345a7aa27abcf39bf0ddbd40 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/rasterize_gl.h @@ -0,0 +1,60 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once + +//------------------------------------------------------------------------ +// Do not try to include OpenGL stuff when compiling CUDA kernels for torch. + +#if !(defined(NVDR_TORCH) && defined(__CUDACC__)) +#include "framework.h" +#include "glutil.h" + +//------------------------------------------------------------------------ +// OpenGL-related persistent state for forward op. + +struct RasterizeGLState // Must be initializable by memset to zero. +{ + int width; // Allocated frame buffer width. + int height; // Allocated frame buffer height. + int depth; // Allocated frame buffer depth. + int posCount; // Allocated position buffer in floats. + int triCount; // Allocated triangle buffer in ints. + GLContext glctx; + GLuint glFBO; + GLuint glColorBuffer[2]; + GLuint glPrevOutBuffer; + GLuint glDepthStencilBuffer; + GLuint glVAO; + GLuint glTriBuffer; + GLuint glPosBuffer; + GLuint glProgram; + GLuint glProgramDP; + GLuint glVertexShader; + GLuint glGeometryShader; + GLuint glFragmentShader; + GLuint glFragmentShaderDP; + cudaGraphicsResource_t cudaColorBuffer[2]; + cudaGraphicsResource_t cudaPrevOutBuffer; + cudaGraphicsResource_t cudaPosBuffer; + cudaGraphicsResource_t cudaTriBuffer; + int enableDB; + int enableZModify; // Modify depth in shader, workaround for a rasterization issue on A100. +}; + +//------------------------------------------------------------------------ +// Shared C++ code prototypes. + +void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceIdx); +void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, bool& changes, int posCount, int triCount, int width, int height, int depth); +void rasterizeRender(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, const float* posPtr, int posCount, int vtxPerInstance, const int32_t* triPtr, int triCount, const int32_t* rangesPtr, int width, int height, int depth, int peeling_idx); +void rasterizeCopyResults(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, float** outputPtr, int width, int height, int depth); +void rasterizeReleaseBuffers(NVDR_CTX_ARGS, RasterizeGLState& s); + +//------------------------------------------------------------------------ +#endif // !(defined(NVDR_TORCH) && defined(__CUDACC__)) diff --git a/extensions/nvdiffrast/nvdiffrast/common/texture.cpp b/extensions/nvdiffrast/nvdiffrast/common/texture.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8c38fb5a2d8fe8075719d24b25162c8c34347c02 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/texture.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "framework.h" +#include "texture.h" + +//------------------------------------------------------------------------ +// Mip stack construction and access helpers. + +void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p) +{ + char buf[1024]; + int bufsz = 1024; + + std::string msg = "Mip-map size error - cannot downsample an odd extent greater than 1. Resize the texture so that both spatial extents are powers of two, or limit the number of mip maps using max_mip_level argument.\n"; + + int w = p.texWidth; + int h = p.texHeight; + bool ew = false; + bool eh = false; + + msg += "Attempted mip stack construction:\n"; + msg += "level width height\n"; + msg += "----- ----- ------\n"; + snprintf(buf, bufsz, "base %5d %5d\n", w, h); + msg += buf; + + int mipTotal = 0; + int level = 0; + while ((w|h) > 1 && !(ew || eh)) // Stop at first impossible size. + { + // Current level. + level += 1; + + // Determine if downsampling fails. + ew = ew || (w > 1 && (w & 1)); + eh = eh || (h > 1 && (h & 1)); + + // Downsample. + if (w > 1) w >>= 1; + if (h > 1) h >>= 1; + + // Append level size to error message. + snprintf(buf, bufsz, "mip %-2d ", level); + msg += buf; + if (ew) snprintf(buf, bufsz, " err "); + else snprintf(buf, bufsz, "%5d ", w); + msg += buf; + if (eh) snprintf(buf, bufsz, " err\n"); + else snprintf(buf, bufsz, "%5d\n", h); + msg += buf; + } + + NVDR_CHECK(0, msg); +} + +int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p, int* mipOffsets) +{ + // No levels at all? + if (p.mipLevelLimit == 0) + { + p.mipLevelMax = 0; + return 0; + } + + // Current level size. + int w = p.texWidth; + int h = p.texHeight; + + int mipTotal = 0; + int level = 0; + int c = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE) ? (p.channels * 6) : p.channels; + mipOffsets[0] = 0; + while ((w|h) > 1) + { + // Current level. + level += 1; + + // Quit if cannot downsample. + if ((w > 1 && (w & 1)) || (h > 1 && (h & 1))) + raiseMipSizeError(NVDR_CTX_PARAMS, p); + + // Downsample. + if (w > 1) w >>= 1; + if (h > 1) h >>= 1; + + mipOffsets[level] = mipTotal; // Store the mip offset (#floats). + mipTotal += w * h * p.texDepth * c; + + // Hit the level limit? + if (p.mipLevelLimit >= 0 && level == p.mipLevelLimit) + break; + } + + p.mipLevelMax = level; + return mipTotal; +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/texture.h b/extensions/nvdiffrast/nvdiffrast/common/texture.h new file mode 100644 index 0000000000000000000000000000000000000000..a8d5ee1b6aa5b19ead3460afd24b9f8e64f456dd --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/texture.h @@ -0,0 +1,78 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "framework.h" + +//------------------------------------------------------------------------ +// Constants. + +#define TEX_DEBUG_MIP_RETAIN_VARIANCE 0 // For debugging +#define TEX_FWD_MAX_KERNEL_BLOCK_WIDTH 8 +#define TEX_FWD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define TEX_FWD_MAX_MIP_KERNEL_BLOCK_WIDTH 8 +#define TEX_FWD_MAX_MIP_KERNEL_BLOCK_HEIGHT 8 +#define TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH 8 +#define TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT 8 +#define TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH 8 +#define TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT 8 +#define TEX_MAX_MIP_LEVEL 16 // Currently a texture cannot be larger than 2 GB because we use 32-bit indices everywhere. +#define TEX_MODE_NEAREST 0 // Nearest on base level. +#define TEX_MODE_LINEAR 1 // Bilinear on base level. +#define TEX_MODE_LINEAR_MIPMAP_NEAREST 2 // Bilinear on nearest mip level. +#define TEX_MODE_LINEAR_MIPMAP_LINEAR 3 // Trilinear. +#define TEX_MODE_COUNT 4 +#define TEX_BOUNDARY_MODE_CUBE 0 // Cube map mode. +#define TEX_BOUNDARY_MODE_WRAP 1 // Wrap (u, v). +#define TEX_BOUNDARY_MODE_CLAMP 2 // Clamp (u, v). +#define TEX_BOUNDARY_MODE_ZERO 3 // Pad with zeros. +#define TEX_BOUNDARY_MODE_COUNT 4 + +//------------------------------------------------------------------------ +// CUDA kernel params. + +struct TextureKernelParams +{ + const float* tex[TEX_MAX_MIP_LEVEL]; // Incoming texture buffer with mip levels. + const float* uv; // Incoming texcoord buffer. + const float* uvDA; // Incoming uv pixel diffs or NULL. + const float* mipLevelBias; // Incoming mip level bias or NULL. + const float* dy; // Incoming output gradient. + float* out; // Outgoing texture data. + float* gradTex[TEX_MAX_MIP_LEVEL]; // Outgoing texture gradients with mip levels. + float* gradUV; // Outgoing texcoord gradient. + float* gradUVDA; // Outgoing texcoord pixel differential gradient. + float* gradMipLevelBias; // Outgoing mip level bias gradient. + int enableMip; // If true, we have uv_da and/or mip_level_bias input(s), and a mip tensor. + int filterMode; // One of the TEX_MODE_ constants. + int boundaryMode; // One of the TEX_BOUNDARY_MODE_ contants. + int texConst; // If true, texture is known to be constant. + int mipLevelLimit; // Mip level limit coming from the op. + int channels; // Number of texture channels. + int imgWidth; // Image width. + int imgHeight; // Image height. + int texWidth; // Texture width. + int texHeight; // Texture height. + int texDepth; // Texture depth. + int n; // Minibatch size. + int mipLevelMax; // Maximum mip level index. Zero if mips disabled. + int mipLevelOut; // Mip level being calculated in builder kernel. +}; + +//------------------------------------------------------------------------ +// C++ helper function prototypes. + +void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p); +int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p, int* mipOffsets); + +//------------------------------------------------------------------------ +// Macros. + +#define mipLevelSize(p, i) make_int2(((p).texWidth >> (i)) > 1 ? ((p).texWidth >> (i)) : 1, ((p).texHeight >> (i)) > 1 ? ((p).texHeight >> (i)) : 1) + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/common/texture_.cu b/extensions/nvdiffrast/nvdiffrast/common/texture_.cu new file mode 100644 index 0000000000000000000000000000000000000000..40bd62974b76a7cd5fa594bb74a1d61330857011 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/common/texture_.cu @@ -0,0 +1,1156 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "common.h" +#include "texture.h" + +//------------------------------------------------------------------------ +// Memory access and math helpers. + +static __device__ __forceinline__ void accum_from_mem(float* a, int s, float b, float c) { a[0] += b * c; } +static __device__ __forceinline__ void accum_from_mem(float* a, int s, float2 b, float c) { a[0] += b.x * c; a[s] += b.y * c; } +static __device__ __forceinline__ void accum_from_mem(float* a, int s, float4 b, float c) { a[0] += b.x * c; a[s] += b.y * c; a[2*s] += b.z * c; a[3*s] += b.w * c; } +static __device__ __forceinline__ void accum_to_mem(float& a, float* b, int s) { a += b[0]; } +static __device__ __forceinline__ void accum_to_mem(float2& a, float* b, int s) { float2 v = a; v.x += b[0]; v.y += b[s]; a = v; } +static __device__ __forceinline__ void accum_to_mem(float4& a, float* b, int s) { float4 v = a; v.x += b[0]; v.y += b[s]; v.z += b[2*s]; v.w += b[3*s]; a = v; } +static __device__ __forceinline__ bool isfinite_vec3(const float3& a) { return isfinite(a.x) && isfinite(a.y) && isfinite(a.z); } +static __device__ __forceinline__ bool isfinite_vec4(const float4& a) { return isfinite(a.x) && isfinite(a.y) && isfinite(a.z) && isfinite(a.w); } +template static __device__ __forceinline__ T lerp (const T& a, const T& b, float c) { return a + c * (b - a); } +template static __device__ __forceinline__ T bilerp(const T& a, const T& b, const T& c, const T& d, const float2& e) { return lerp(lerp(a, b, e.x), lerp(c, d, e.x), e.y); } + +//------------------------------------------------------------------------ +// Cube map wrapping for smooth filtering across edges and corners. At corners, +// one of the texture coordinates will be negative. For correct interpolation, +// the missing texel must take the average color of the other three. + +static __constant__ uint32_t c_cubeWrapMask1[48] = +{ + 0x1530a440, 0x1133a550, 0x6103a110, 0x1515aa44, 0x6161aa11, 0x40154a04, 0x44115a05, 0x04611a01, + 0x2630a440, 0x2233a550, 0x5203a110, 0x2626aa44, 0x5252aa11, 0x40264a04, 0x44225a05, 0x04521a01, + 0x32608064, 0x3366a055, 0x13062091, 0x32328866, 0x13132299, 0x50320846, 0x55330a55, 0x05130219, + 0x42508064, 0x4455a055, 0x14052091, 0x42428866, 0x14142299, 0x60420846, 0x66440a55, 0x06140219, + 0x5230a044, 0x5533a055, 0x1503a011, 0x5252aa44, 0x1515aa11, 0x40520a44, 0x44550a55, 0x04150a11, + 0x6130a044, 0x6633a055, 0x2603a011, 0x6161aa44, 0x2626aa11, 0x40610a44, 0x44660a55, 0x04260a11, +}; + +static __constant__ uint8_t c_cubeWrapMask2[48] = +{ + 0x26, 0x33, 0x11, 0x05, 0x00, 0x09, 0x0c, 0x04, 0x04, 0x00, 0x00, 0x05, 0x00, 0x81, 0xc0, 0x40, + 0x02, 0x03, 0x09, 0x00, 0x0a, 0x00, 0x00, 0x02, 0x64, 0x30, 0x90, 0x55, 0xa0, 0x99, 0xcc, 0x64, + 0x24, 0x30, 0x10, 0x05, 0x00, 0x01, 0x00, 0x00, 0x06, 0x03, 0x01, 0x05, 0x00, 0x89, 0xcc, 0x44, +}; + +static __device__ __forceinline__ int4 wrapCubeMap(int face, int ix0, int ix1, int iy0, int iy1, int w) +{ + // Calculate case number. + int cx = (ix0 < 0) ? 0 : (ix1 >= w) ? 2 : 1; + int cy = (iy0 < 0) ? 0 : (iy1 >= w) ? 6 : 3; + int c = cx + cy; + if (c >= 5) + c--; + c = (face << 3) + c; + + // Compute coordinates and faces. + unsigned int m = c_cubeWrapMask1[c]; + int x0 = (m >> 0) & 3; x0 = (x0 == 0) ? 0 : (x0 == 1) ? ix0 : iy0; + int x1 = (m >> 2) & 3; x1 = (x1 == 0) ? 0 : (x1 == 1) ? ix1 : iy0; + int x2 = (m >> 4) & 3; x2 = (x2 == 0) ? 0 : (x2 == 1) ? ix0 : iy1; + int x3 = (m >> 6) & 3; x3 = (x3 == 0) ? 0 : (x3 == 1) ? ix1 : iy1; + int y0 = (m >> 8) & 3; y0 = (y0 == 0) ? 0 : (y0 == 1) ? ix0 : iy0; + int y1 = (m >> 10) & 3; y1 = (y1 == 0) ? 0 : (y1 == 1) ? ix1 : iy0; + int y2 = (m >> 12) & 3; y2 = (y2 == 0) ? 0 : (y2 == 1) ? ix0 : iy1; + int y3 = (m >> 14) & 3; y3 = (y3 == 0) ? 0 : (y3 == 1) ? ix1 : iy1; + int f0 = ((m >> 16) & 15) - 1; + int f1 = ((m >> 20) & 15) - 1; + int f2 = ((m >> 24) & 15) - 1; + int f3 = ((m >> 28) ) - 1; + + // Flips. + unsigned int f = c_cubeWrapMask2[c]; + int w1 = w - 1; + if (f & 0x01) x0 = w1 - x0; + if (f & 0x02) x1 = w1 - x1; + if (f & 0x04) x2 = w1 - x2; + if (f & 0x08) x3 = w1 - x3; + if (f & 0x10) y0 = w1 - y0; + if (f & 0x20) y1 = w1 - y1; + if (f & 0x40) y2 = w1 - y2; + if (f & 0x80) y3 = w1 - y3; + + // Done. + int4 tcOut; + tcOut.x = x0 + (y0 + f0 * w) * w; + tcOut.y = x1 + (y1 + f1 * w) * w; + tcOut.z = x2 + (y2 + f2 * w) * w; + tcOut.w = x3 + (y3 + f3 * w) * w; + return tcOut; +} + +//------------------------------------------------------------------------ +// Cube map indexing and gradient functions. + +// Map a 3D lookup vector into an (s,t) face coordinates (returned in first . +// two parameters) and face index. +static __device__ __forceinline__ int indexCubeMap(float& x, float& y, float z) +{ + float ax = fabsf(x); + float ay = fabsf(y); + float az = fabsf(z); + int idx; + float c; + if (az > fmaxf(ax, ay)) { idx = 4; c = z; } + else if (ay > ax) { idx = 2; c = y; y = z; } + else { idx = 0; c = x; x = z; } + if (c < 0.f) idx += 1; + float m = __frcp_rz(fabsf(c)) * .5; + float m0 = __uint_as_float(__float_as_uint(m) ^ ((0x21u >> idx) << 31)); + float m1 = (idx != 2) ? -m : m; + x = x * m0 + .5; + y = y * m1 + .5; + if (!isfinite(x) || !isfinite(y)) + return -1; // Invalid uv. + x = fminf(fmaxf(x, 0.f), 1.f); + y = fminf(fmaxf(y, 0.f), 1.f); + return idx; +} + +// Based on dA/d{s,t}, compute dA/d{x,y,z} at a given 3D lookup vector. +static __device__ __forceinline__ float3 indexCubeMapGrad(float3 uv, float gu, float gv) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c; + float c0 = gu; + float c1 = gv; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; c0 *= uv.x; c1 *= uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; c0 *= uv.x; c1 *= uv.z; } + else { idx = 0x01; c = uv.x; c0 *= uv.z; c1 *= uv.y; } + if (c < 0.f) idx += idx; + float m = __frcp_rz(fabsf(c)); + c0 = (idx & 0x34) ? -c0 : c0; + c1 = (idx & 0x2e) ? -c1 : c1; + float gl = (c0 + c1) * m; + float gx = (idx & 0x03) ? gl : (idx & 0x20) ? -gu : gu; + float gy = (idx & 0x0c) ? gl : -gv; + float gz = (idx & 0x30) ? gl : (idx & 0x03) ? gu : gv; + gz = (idx & 0x09) ? -gz : gz; + float3 res = make_float3(gx, gy, gz) * (m * .5f); + if (!isfinite_vec3(res)) + return make_float3(0.f, 0.f, 0.f); // Invalid uv. + return res; +} + +// Based on dL/d(d{s,t}/s{X,Y}), compute dL/d(d{x,y,z}/d{X,Y}). This is just two +// indexCubeMapGrad() functions rolled together. +static __device__ __forceinline__ void indexCubeMapGrad4(float3 uv, float4 dw, float3& g0, float3& g1) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c, c0, c1; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; c0 = uv.x; c1 = uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; c0 = uv.x; c1 = uv.z; } + else { idx = 0x01; c = uv.x; c0 = uv.z; c1 = uv.y; } + if (c < 0.f) idx += idx; + float m = __frcp_rz(fabsf(c)); + c0 = (idx & 0x34) ? -c0 : c0; + c1 = (idx & 0x2e) ? -c1 : c1; + float gl0 = (dw.x * c0 + dw.z * c1) * m; + float gl1 = (dw.y * c0 + dw.w * c1) * m; + float gx0 = (idx & 0x03) ? gl0 : (idx & 0x20) ? -dw.x : dw.x; + float gx1 = (idx & 0x03) ? gl1 : (idx & 0x20) ? -dw.y : dw.y; + float gy0 = (idx & 0x0c) ? gl0 : -dw.z; + float gy1 = (idx & 0x0c) ? gl1 : -dw.w; + float gz0 = (idx & 0x30) ? gl0 : (idx & 0x03) ? dw.x : dw.z; + float gz1 = (idx & 0x30) ? gl1 : (idx & 0x03) ? dw.y : dw.w; + if (idx & 0x09) + { + gz0 = -gz0; + gz1 = -gz1; + } + g0 = make_float3(gx0, gy0, gz0) * (m * .5f); + g1 = make_float3(gx1, gy1, gz1) * (m * .5f); + if (!isfinite_vec3(g0) || !isfinite_vec3(g1)) + { + g0 = make_float3(0.f, 0.f, 0.f); // Invalid uv. + g1 = make_float3(0.f, 0.f, 0.f); + } +} + +// Compute d{s,t}/d{X,Y} based on d{x,y,z}/d{X,Y} at a given 3D lookup vector. +// Result is (ds/dX, ds/dY, dt/dX, dt/dY). +static __device__ __forceinline__ float4 indexCubeMapGradST(float3 uv, float3 dvdX, float3 dvdY) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c, gu, gv; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; gu = uv.x; gv = uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; gu = uv.x; gv = uv.z; } + else { idx = 0x01; c = uv.x; gu = uv.z; gv = uv.y; } + if (c < 0.f) idx += idx; + if (idx & 0x09) + { + dvdX.z = -dvdX.z; + dvdY.z = -dvdY.z; + } + float m = __frcp_rz(fabsf(c)); + float dm = m * .5f; + float mm = m * dm; + gu *= (idx & 0x34) ? -mm : mm; + gv *= (idx & 0x2e) ? -mm : mm; + + float4 res; + if (idx & 0x03) + { + res = make_float4(gu * dvdX.x + dm * dvdX.z, + gu * dvdY.x + dm * dvdY.z, + gv * dvdX.x - dm * dvdX.y, + gv * dvdY.x - dm * dvdY.y); + } + else if (idx & 0x0c) + { + res = make_float4(gu * dvdX.y + dm * dvdX.x, + gu * dvdY.y + dm * dvdY.x, + gv * dvdX.y + dm * dvdX.z, + gv * dvdY.y + dm * dvdY.z); + } + else // (idx & 0x30) + { + res = make_float4(gu * dvdX.z + copysignf(dm, c) * dvdX.x, + gu * dvdY.z + copysignf(dm, c) * dvdY.x, + gv * dvdX.z - dm * dvdX.y, + gv * dvdY.z - dm * dvdY.y); + } + + if (!isfinite_vec4(res)) + return make_float4(0.f, 0.f, 0.f, 0.f); + + return res; +} + +// Compute d(d{s,t}/d{X,Y})/d{x,y,z}, i.e., how the pixel derivatives of 2D face +// coordinates change w.r.t. 3D texture coordinate vector, returned as follows: +// | d(ds/dX)/dx d(ds/dY)/dx d(dt/dX)/dx d(dt/dY)/dx | +// | d(ds/dX)/dy d(ds/dY)/dy d(dt/dX)/dy d(dt/dY)/dy | +// | d(ds/dX)/dz d(ds/dY)/dz d(dt/dX)/dz d(dt/dY)/dz | +static __device__ __forceinline__ void indexCubeMapGrad2(float3 uv, float3 dvdX, float3 dvdY, float4& dx, float4& dy, float4& dz) +{ + float ax = fabsf(uv.x); + float ay = fabsf(uv.y); + float az = fabsf(uv.z); + int idx; + float c, gu, gv; + if (az > fmaxf(ax, ay)) { idx = 0x10; c = uv.z; gu = uv.x; gv = uv.y; } + else if (ay > ax) { idx = 0x04; c = uv.y; gu = uv.x; gv = uv.z; } + else { idx = 0x01; c = uv.x; gu = uv.z; gv = uv.y; } + if (c < 0.f) idx += idx; + + if (idx & 0x09) + { + dvdX.z = -dvdX.z; + dvdY.z = -dvdY.z; + } + + float m = __frcp_rz(c); + float dm = -m * fabsf(m) * .5; + float mm = m * m * .5; + float mu = (idx & 0x34) ? -mm : mm; + float mv = (idx & 0x2e) ? -mm : mm; + gu *= -2.0 * m * mu; + gv *= -2.0 * m * mv; + + if (idx & 0x03) + { + dx.x = gu * dvdX.x + dm * dvdX.z; + dx.y = gu * dvdY.x + dm * dvdY.z; + dx.z = gv * dvdX.x - dm * dvdX.y; + dx.w = gv * dvdY.x - dm * dvdY.y; + dy.x = 0.f; + dy.y = 0.f; + dy.z = mv * dvdX.x; + dy.w = mv * dvdY.x; + dz.x = mu * dvdX.x; + dz.y = mu * dvdY.x; + dz.z = 0.f; + dz.w = 0.f; + } + else if (idx & 0x0c) + { + dx.x = mu * dvdX.y; + dx.y = mu * dvdY.y; + dx.z = 0.f; + dx.w = 0.f; + dy.x = gu * dvdX.y + dm * dvdX.x; + dy.y = gu * dvdY.y + dm * dvdY.x; + dy.z = gv * dvdX.y + dm * dvdX.z; + dy.w = gv * dvdY.y + dm * dvdY.z; + dz.x = 0.f; + dz.y = 0.f; + dz.z = mv * dvdX.y; + dz.w = mv * dvdY.y; + } + else // (idx & 0x30) + { + dx.x = mu * dvdX.z; + dx.y = mu * dvdY.z; + dx.z = 0.f; + dx.w = 0.f; + dy.x = 0.f; + dy.y = 0.f; + dy.z = mv * dvdX.z; + dy.w = mv * dvdY.z; + dz.x = gu * dvdX.z - fabsf(dm) * dvdX.x; + dz.y = gu * dvdY.z - fabsf(dm) * dvdY.x; + dz.z = gv * dvdX.z - dm * dvdX.y; + dz.w = gv * dvdY.z - dm * dvdY.y; + } +} + +//------------------------------------------------------------------------ +// General texture indexing. + +template +static __device__ __forceinline__ int indexTextureNearest(const TextureKernelParams& p, float3 uv, int tz) +{ + int w = p.texWidth; + int h = p.texHeight; + float u = uv.x; + float v = uv.y; + + // Cube map indexing. + if (CUBE_MODE) + { + // No wrap. Fold face index into tz right away. + int idx = indexCubeMap(u, v, uv.z); // Rewrites u, v. + if (idx < 0) + return -1; // Invalid uv. + tz = 6 * tz + idx; + } + else + { + // Handle boundary. + if (p.boundaryMode == TEX_BOUNDARY_MODE_WRAP) + { + u = u - (float)__float2int_rd(u); + v = v - (float)__float2int_rd(v); + } + } + + u = u * (float)w; + v = v * (float)h; + + int iu = __float2int_rd(u); + int iv = __float2int_rd(v); + + // In zero boundary mode, return texture address -1. + if (!CUBE_MODE && p.boundaryMode == TEX_BOUNDARY_MODE_ZERO) + { + if (iu < 0 || iu >= w || iv < 0 || iv >= h) + return -1; + } + + // Otherwise clamp and calculate the coordinate properly. + iu = min(max(iu, 0), w-1); + iv = min(max(iv, 0), h-1); + return iu + w * (iv + tz * h); +} + +template +static __device__ __forceinline__ float2 indexTextureLinear(const TextureKernelParams& p, float3 uv, int tz, int4& tcOut, int level) +{ + // Mip level size. + int2 sz = mipLevelSize(p, level); + int w = sz.x; + int h = sz.y; + + // Compute texture-space u, v. + float u = uv.x; + float v = uv.y; + bool clampU = false; + bool clampV = false; + + // Cube map indexing. + int face = 0; + if (CUBE_MODE) + { + // Neither clamp or wrap. + face = indexCubeMap(u, v, uv.z); // Rewrites u, v. + if (face < 0) + { + tcOut.x = tcOut.y = tcOut.z = tcOut.w = -1; // Invalid uv. + return make_float2(0.f, 0.f); + } + u = u * (float)w - 0.5f; + v = v * (float)h - 0.5f; + } + else + { + if (p.boundaryMode == TEX_BOUNDARY_MODE_WRAP) + { + // Wrap. + u = u - (float)__float2int_rd(u); + v = v - (float)__float2int_rd(v); + } + + // Move to texel space. + u = u * (float)w - 0.5f; + v = v * (float)h - 0.5f; + + if (p.boundaryMode == TEX_BOUNDARY_MODE_CLAMP) + { + // Clamp to center of edge texels. + u = fminf(fmaxf(u, 0.f), w - 1.f); + v = fminf(fmaxf(v, 0.f), h - 1.f); + clampU = (u == 0.f || u == w - 1.f); + clampV = (v == 0.f || v == h - 1.f); + } + } + + // Compute texel coordinates and weights. + int iu0 = __float2int_rd(u); + int iv0 = __float2int_rd(v); + int iu1 = iu0 + (clampU ? 0 : 1); // Ensure zero u/v gradients with clamped. + int iv1 = iv0 + (clampV ? 0 : 1); + u -= (float)iu0; + v -= (float)iv0; + + // Cube map wrapping. + bool cubeWrap = CUBE_MODE && (iu0 < 0 || iv0 < 0 || iu1 >= w || iv1 >= h); + if (cubeWrap) + { + tcOut = wrapCubeMap(face, iu0, iu1, iv0, iv1, w); + tcOut += 6 * tz * w * h; // Bring in tz. + return make_float2(u, v); // Done. + } + + // Fold cube map face into tz. + if (CUBE_MODE) + tz = 6 * tz + face; + + // Wrap overflowing texel indices. + if (!CUBE_MODE && p.boundaryMode == TEX_BOUNDARY_MODE_WRAP) + { + if (iu0 < 0) iu0 += w; + if (iv0 < 0) iv0 += h; + if (iu1 >= w) iu1 -= w; + if (iv1 >= h) iv1 -= h; + } + + // Coordinates with tz folded in. + int iu0z = iu0 + tz * w * h; + int iu1z = iu1 + tz * w * h; + tcOut.x = iu0z + w * iv0; + tcOut.y = iu1z + w * iv0; + tcOut.z = iu0z + w * iv1; + tcOut.w = iu1z + w * iv1; + + // Invalidate texture addresses outside unit square if we are in zero mode. + if (!CUBE_MODE && p.boundaryMode == TEX_BOUNDARY_MODE_ZERO) + { + bool iu0_out = (iu0 < 0 || iu0 >= w); + bool iu1_out = (iu1 < 0 || iu1 >= w); + bool iv0_out = (iv0 < 0 || iv0 >= h); + bool iv1_out = (iv1 < 0 || iv1 >= h); + if (iu0_out || iv0_out) tcOut.x = -1; + if (iu1_out || iv0_out) tcOut.y = -1; + if (iu0_out || iv1_out) tcOut.z = -1; + if (iu1_out || iv1_out) tcOut.w = -1; + } + + // All done. + return make_float2(u, v); +} + +//------------------------------------------------------------------------ +// Mip level calculation. + +template +static __device__ __forceinline__ void calculateMipLevel(int& level0, int& level1, float& flevel, const TextureKernelParams& p, int pidx, float3 uv, float4* pdw, float3* pdfdv) +{ + // Do nothing if mips not in use. + if (FILTER_MODE == TEX_MODE_NEAREST || FILTER_MODE == TEX_MODE_LINEAR) + return; + + // Determine mip level based on UV pixel derivatives. If no derivatives are given (mip level bias only), leave as zero. + if (!BIAS_ONLY) + { + // Get pixel derivatives of texture coordinates. + float4 uvDA; + float3 dvdX, dvdY; // Gradients use these later. + if (CUBE_MODE) + { + // Fetch. + float2 d0 = ((const float2*)p.uvDA)[3 * pidx + 0]; + float2 d1 = ((const float2*)p.uvDA)[3 * pidx + 1]; + float2 d2 = ((const float2*)p.uvDA)[3 * pidx + 2]; + + // Map d{x,y,z}/d{X,Y} into d{s,t}/d{X,Y}. + dvdX = make_float3(d0.x, d1.x, d2.x); // d{x,y,z}/dX + dvdY = make_float3(d0.y, d1.y, d2.y); // d{x,y,z}/dY + uvDA = indexCubeMapGradST(uv, dvdX, dvdY); // d{s,t}/d{X,Y} + } + else + { + // Fetch. + uvDA = ((const float4*)p.uvDA)[pidx]; + } + + // Scaling factors. + float uscl = p.texWidth; + float vscl = p.texHeight; + + // d[s,t]/d[X,Y]. + float dsdx = uvDA.x * uscl; + float dsdy = uvDA.y * uscl; + float dtdx = uvDA.z * vscl; + float dtdy = uvDA.w * vscl; + + // Calculate footprint axis lengths. + float A = dsdx*dsdx + dtdx*dtdx; + float B = dsdy*dsdy + dtdy*dtdy; + float C = dsdx*dsdy + dtdx*dtdy; + float l2b = 0.5 * (A + B); + float l2n = 0.25 * (A-B)*(A-B) + C*C; + float l2a = sqrt(l2n); + float lenMinorSqr = fmaxf(0.0, l2b - l2a); + float lenMajorSqr = l2b + l2a; + + // Footprint vs. mip level gradient. + if (pdw && FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + float dw = 0.72134752f / (l2n + l2a * l2b); // Constant is 0.5/ln(2). + float AB = dw * .5f * (A - B); + float Cw = dw * C; + float l2aw = dw * l2a; + float d_f_ddsdX = uscl * (dsdx * (l2aw + AB) + dsdy * Cw); + float d_f_ddsdY = uscl * (dsdy * (l2aw - AB) + dsdx * Cw); + float d_f_ddtdX = vscl * (dtdx * (l2aw + AB) + dtdy * Cw); + float d_f_ddtdY = vscl * (dtdy * (l2aw - AB) + dtdx * Cw); + + float4 d_f_dw = make_float4(d_f_ddsdX, d_f_ddsdY, d_f_ddtdX, d_f_ddtdY); + if (!CUBE_MODE) + *pdw = isfinite_vec4(d_f_dw) ? d_f_dw : make_float4(0.f, 0.f, 0.f, 0.f); + + // In cube maps, there is also a texture coordinate vs. mip level gradient. + // Only output nonzero vectors if both are free of inf/Nan garbage. + if (CUBE_MODE) + { + float4 dx, dy, dz; + indexCubeMapGrad2(uv, dvdX, dvdY, dx, dy, dz); + float3 d_dsdX_dv = make_float3(dx.x, dy.x, dz.x); + float3 d_dsdY_dv = make_float3(dx.y, dy.y, dz.y); + float3 d_dtdX_dv = make_float3(dx.z, dy.z, dz.z); + float3 d_dtdY_dv = make_float3(dx.w, dy.w, dz.w); + + float3 d_f_dv = make_float3(0.f, 0.f, 0.f); + d_f_dv += d_dsdX_dv * d_f_ddsdX; + d_f_dv += d_dsdY_dv * d_f_ddsdY; + d_f_dv += d_dtdX_dv * d_f_ddtdX; + d_f_dv += d_dtdY_dv * d_f_ddtdY; + + bool finite = isfinite_vec4(d_f_dw) && isfinite_vec3(d_f_dv); + *pdw = finite ? d_f_dw : make_float4(0.f, 0.f, 0.f, 0.f); + *pdfdv = finite ? d_f_dv : make_float3(0.f, 0.f, 0.f); + } + } + + // Finally, calculate mip level. + flevel = .5f * __log2f(lenMajorSqr); // May be inf/NaN, but clamp fixes it. + } + + // Bias the mip level and clamp. + if (p.mipLevelBias) + flevel += p.mipLevelBias[pidx]; + flevel = fminf(fmaxf(flevel, 0.f), (float)p.mipLevelMax); + + // Calculate levels depending on filter mode. + level0 = __float2int_rd(flevel); + + // Leave everything else at zero if flevel == 0 (magnification) or when in linear-mipmap-nearest mode. + if (FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR && flevel > 0.f) + { + level1 = min(level0 + 1, p.mipLevelMax); + flevel -= level0; // Fractional part. Zero if clamped on last level. + } +} + +//------------------------------------------------------------------------ +// Texel fetch and accumulator helpers that understand cube map corners. + +template +static __device__ __forceinline__ void fetchQuad(T& a00, T& a10, T& a01, T& a11, const float* pIn, int4 tc, bool corner) +{ + // For invalid cube map uv, tc will be all negative, and all texel values will be zero. + if (corner) + { + T avg = zero_value(); + if (tc.x >= 0) avg += (a00 = *((const T*)&pIn[tc.x])); + if (tc.y >= 0) avg += (a10 = *((const T*)&pIn[tc.y])); + if (tc.z >= 0) avg += (a01 = *((const T*)&pIn[tc.z])); + if (tc.w >= 0) avg += (a11 = *((const T*)&pIn[tc.w])); + avg *= 0.33333333f; + if (tc.x < 0) a00 = avg; + if (tc.y < 0) a10 = avg; + if (tc.z < 0) a01 = avg; + if (tc.w < 0) a11 = avg; + } + else + { + a00 = (tc.x >= 0) ? *((const T*)&pIn[tc.x]) : zero_value(); + a10 = (tc.y >= 0) ? *((const T*)&pIn[tc.y]) : zero_value(); + a01 = (tc.z >= 0) ? *((const T*)&pIn[tc.z]) : zero_value(); + a11 = (tc.w >= 0) ? *((const T*)&pIn[tc.w]) : zero_value(); + } +} + +static __device__ __forceinline__ void accumQuad(float4 c, float* pOut, int level, int4 tc, bool corner, CA_TEMP_PARAM) +{ + // For invalid cube map uv, tc will be all negative, and no accumulation will take place. + if (corner) + { + float cb; + if (tc.x < 0) cb = c.x; + if (tc.y < 0) cb = c.y; + if (tc.z < 0) cb = c.z; + if (tc.w < 0) cb = c.w; + cb *= 0.33333333f; + if (tc.x >= 0) caAtomicAddTexture(pOut, level, tc.x, c.x + cb); + if (tc.y >= 0) caAtomicAddTexture(pOut, level, tc.y, c.y + cb); + if (tc.z >= 0) caAtomicAddTexture(pOut, level, tc.z, c.z + cb); + if (tc.w >= 0) caAtomicAddTexture(pOut, level, tc.w, c.w + cb); + } + else + { + if (tc.x >= 0) caAtomicAddTexture(pOut, level, tc.x, c.x); + if (tc.y >= 0) caAtomicAddTexture(pOut, level, tc.y, c.y); + if (tc.z >= 0) caAtomicAddTexture(pOut, level, tc.z, c.z); + if (tc.w >= 0) caAtomicAddTexture(pOut, level, tc.w, c.w); + } +} + +//------------------------------------------------------------------------ +// Mip builder kernel. + +template +static __forceinline__ __device__ void MipBuildKernelTemplate(const TextureKernelParams p) +{ + // Sizes. + int2 sz_in = mipLevelSize(p, p.mipLevelOut - 1); + int2 sz_out = mipLevelSize(p, p.mipLevelOut); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= sz_out.x || py >= sz_out.y) + return; + + // Pixel indices. + int pidx_in0 = p.channels * (((px + sz_in.x * py) << 1) + (pz * sz_in.x * sz_in.y)); + int pidx_in1 = pidx_in0 + p.channels * sz_in.x; // Next pixel down. + int pidx_out = p.channels * (px + sz_out.x * (py + sz_out.y * pz)); + + // Input and output pointers. + const float* pin = p.tex[p.mipLevelOut - 1]; + float* pout = (float*)p.tex[p.mipLevelOut]; + + // Special case: Input texture height or width is 1. + if (sz_in.x == 1 || sz_in.y == 1) + { + if (sz_in.y == 1) + pidx_in1 = pidx_in0 + p.channels; // Next pixel on the right. + + for (int i=0; i < p.channels; i += C) + { + T v0 = *((const T*)&pin[pidx_in0 + i]); + T v1 = *((const T*)&pin[pidx_in1 + i]); + T avg = .5f * (v0 + v1); +#if TEX_DEBUG_MIP_RETAIN_VARIANCE + avg = (avg - .5f) * 1.41421356f + .5f; +#endif + *((T*)&pout[pidx_out + i]) = avg; + } + + return; + } + + for (int i=0; i < p.channels; i += C) + { + T v0 = *((const T*)&pin[pidx_in0 + i]); + T v1 = *((const T*)&pin[pidx_in0 + i + p.channels]); + T v2 = *((const T*)&pin[pidx_in1 + i]); + T v3 = *((const T*)&pin[pidx_in1 + i + p.channels]); + T avg = .25f * (v0 + v1 + v2 + v3); +#if TEX_DEBUG_MIP_RETAIN_VARIANCE + avg = (avg - .5f) * 2.f + .5f; +#endif + *((T*)&pout[pidx_out + i]) = avg; + } +} + +// Template specializations. +__global__ void MipBuildKernel1(const TextureKernelParams p) { MipBuildKernelTemplate(p); } +__global__ void MipBuildKernel2(const TextureKernelParams p) { MipBuildKernelTemplate(p); } +__global__ void MipBuildKernel4(const TextureKernelParams p) { MipBuildKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Forward kernel. + +template +static __forceinline__ __device__ void TextureFwdKernelTemplate(const TextureKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + int tz = (p.texDepth == 1) ? 0 : pz; + if (px >= p.imgWidth || py >= p.imgHeight || pz >= p.n) + return; + + // Pixel index. + int pidx = px + p.imgWidth * (py + p.imgHeight * pz); + + // Output ptr. + float* pOut = p.out + pidx * p.channels; + + // Get UV. + float3 uv; + if (CUBE_MODE) + uv = ((const float3*)p.uv)[pidx]; + else + uv = make_float3(((const float2*)p.uv)[pidx], 0.f); + + // Nearest mode. + if (FILTER_MODE == TEX_MODE_NEAREST) + { + int tc = indexTextureNearest(p, uv, tz); + tc *= p.channels; + const float* pIn = p.tex[0]; + + // Copy if valid tc, otherwise output zero. + for (int i=0; i < p.channels; i += C) + *((T*)&pOut[i]) = (tc >= 0) ? *((const T*)&pIn[tc + i]) : zero_value(); + + return; // Exit. + } + + // Calculate mip level. In 'linear' mode these will all stay zero. + float flevel = 0.f; // Fractional level. + int level0 = 0; // Discrete level 0. + int level1 = 0; // Discrete level 1. + calculateMipLevel(level0, level1, flevel, p, pidx, uv, 0, 0); + + // Get texel indices and pointer for level 0. + int4 tc0 = make_int4(0, 0, 0, 0); + float2 uv0 = indexTextureLinear(p, uv, tz, tc0, level0); + const float* pIn0 = p.tex[level0]; + bool corner0 = CUBE_MODE && ((tc0.x | tc0.y | tc0.z | tc0.w) < 0); + tc0 *= p.channels; + + // Bilinear fetch. + if (FILTER_MODE == TEX_MODE_LINEAR || FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_NEAREST) + { + // Interpolate. + for (int i=0; i < p.channels; i += C, tc0 += C) + { + T a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + *((T*)&pOut[i]) = bilerp(a00, a10, a01, a11, uv0); + } + return; // Exit. + } + + // Get texel indices and pointer for level 1. + int4 tc1 = make_int4(0, 0, 0, 0); + float2 uv1 = indexTextureLinear(p, uv, tz, tc1, level1); + const float* pIn1 = p.tex[level1]; + bool corner1 = CUBE_MODE && ((tc1.x | tc1.y | tc1.z | tc1.w) < 0); + tc1 *= p.channels; + + // Trilinear fetch. + for (int i=0; i < p.channels; i += C, tc0 += C, tc1 += C) + { + // First level. + T a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + T a = bilerp(a00, a10, a01, a11, uv0); + + // Second level unless in magnification mode. + if (flevel > 0.f) + { + T b00, b10, b01, b11; + fetchQuad(b00, b10, b01, b11, pIn1, tc1, corner1); + T b = bilerp(b00, b10, b01, b11, uv1); + a = lerp(a, b, flevel); // Interpolate between levels. + } + + // Write. + *((T*)&pOut[i]) = a; + } +} + +// Template specializations. +__global__ void TextureFwdKernelNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearest1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearest2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearest4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinear1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinear2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinear4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearestBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearestBO2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapNearestBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinearBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinearBO2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelLinearMipmapLinearBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearestBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearestBO2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapNearestBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinearBO1 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinearBO2 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } +__global__ void TextureFwdKernelCubeLinearMipmapLinearBO4 (const TextureKernelParams p) { TextureFwdKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Gradient mip puller kernel. + +template +static __forceinline__ __device__ void MipGradKernelTemplate(const TextureKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.texWidth || py >= p.texHeight) + return; + + // Number of wide elements. + int c = p.channels; + if (C == 2) c >>= 1; + if (C == 4) c >>= 2; + + // Dynamically allocated shared memory for holding a texel. + extern __shared__ float s_texelAccum[]; + int sharedOfs = threadIdx.x + threadIdx.y * blockDim.x; + int sharedStride = blockDim.x * blockDim.y; +# define TEXEL_ACCUM(_i) (s_texelAccum + (sharedOfs + (_i) * sharedStride)) + + // Clear the texel. + for (int i=0; i < p.channels; i++) + *TEXEL_ACCUM(i) = 0.f; + + // Track texel position and accumulation weight over the mip stack. + int x = px; + int y = py; + float w = 1.f; + + // Pull gradients from all levels. + int2 sz = mipLevelSize(p, 0); // Previous level size. + for (int level=1; level <= p.mipLevelMax; level++) + { + // Weight decay depends on previous level size. + if (sz.x > 1) w *= .5f; + if (sz.y > 1) w *= .5f; + + // Current level size and coordinates. + sz = mipLevelSize(p, level); + x >>= 1; + y >>= 1; + + T* pIn = (T*)(p.gradTex[level] + (x + sz.x * (y + sz.y * pz)) * p.channels); + for (int i=0; i < c; i++) + accum_from_mem(TEXEL_ACCUM(i * C), sharedStride, pIn[i], w); + } + + // Add to main texture gradients. + T* pOut = (T*)(p.gradTex[0] + (px + p.texWidth * (py + p.texHeight * pz)) * p.channels); + for (int i=0; i < c; i++) + accum_to_mem(pOut[i], TEXEL_ACCUM(i * C), sharedStride); +} + +// Template specializations. +__global__ void MipGradKernel1(const TextureKernelParams p) { MipGradKernelTemplate(p); } +__global__ void MipGradKernel2(const TextureKernelParams p) { MipGradKernelTemplate(p); } +__global__ void MipGradKernel4(const TextureKernelParams p) { MipGradKernelTemplate(p); } + +//------------------------------------------------------------------------ +// Gradient kernel. + +template +static __forceinline__ __device__ void TextureGradKernelTemplate(const TextureKernelParams p) +{ + // Temporary space for coalesced atomics. + CA_DECLARE_TEMP(TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH * TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT); + + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + int tz = (p.texDepth == 1) ? 0 : pz; + if (px >= p.imgWidth || py >= p.imgHeight || pz >= p.n) + return; + + // Pixel index. + int pidx = px + p.imgWidth * (py + p.imgHeight * pz); + + // Early exit if output gradients are zero. + const float* pDy = p.dy + pidx * p.channels; + unsigned int dmax = 0u; + if ((p.channels & 3) == 0) + { + for (int i=0; i < p.channels; i += 4) + { + uint4 dy = *((const uint4*)&pDy[i]); + dmax |= (dy.x | dy.y | dy.z | dy.w); + } + } + else + { + for (int i=0; i < p.channels; i++) + dmax |= __float_as_uint(pDy[i]); + } + + // Store zeros and exit. + if (__uint_as_float(dmax) == 0.f) + { + if (CUBE_MODE) + { + if (FILTER_MODE != TEX_MODE_NEAREST) + ((float3*)p.gradUV)[pidx] = make_float3(0.f, 0.f, 0.f); + if (FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + if (p.gradUVDA) + { + ((float2*)p.gradUVDA)[3 * pidx + 0] = make_float2(0.f, 0.f); + ((float2*)p.gradUVDA)[3 * pidx + 1] = make_float2(0.f, 0.f); + ((float2*)p.gradUVDA)[3 * pidx + 2] = make_float2(0.f, 0.f); + } + if (p.gradMipLevelBias) + p.gradMipLevelBias[pidx] = 0.f; + } + } + else + { + if (FILTER_MODE != TEX_MODE_NEAREST) + ((float2*)p.gradUV)[pidx] = make_float2(0.f, 0.f); + if (FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + if (p.gradUVDA) + ((float4*)p.gradUVDA)[pidx] = make_float4(0.f, 0.f, 0.f, 0.f); + if (p.gradMipLevelBias) + p.gradMipLevelBias[pidx] = 0.f; + } + } + return; + } + + // Get UV. + float3 uv; + if (CUBE_MODE) + uv = ((const float3*)p.uv)[pidx]; + else + uv = make_float3(((const float2*)p.uv)[pidx], 0.f); + + // Nearest mode - texture gradients only. + if (FILTER_MODE == TEX_MODE_NEAREST) + { + int tc = indexTextureNearest(p, uv, tz); + if (tc < 0) + return; // Outside texture. + + tc *= p.channels; + float* pOut = p.gradTex[0]; + + // Accumulate texture gradients. + for (int i=0; i < p.channels; i++) + caAtomicAddTexture(pOut, 0, tc + i, pDy[i]); + + return; // Exit. + } + + // Calculate mip level. In 'linear' mode these will all stay zero. + float4 dw = make_float4(0.f, 0.f, 0.f, 0.f); + float3 dfdv = make_float3(0.f, 0.f, 0.f); + float flevel = 0.f; // Fractional level. + int level0 = 0; // Discrete level 0. + int level1 = 0; // Discrete level 1. + calculateMipLevel(level0, level1, flevel, p, pidx, uv, &dw, &dfdv); + + // UV gradient accumulators. + float gu = 0.f; + float gv = 0.f; + + // Get texel indices and pointers for level 0. + int4 tc0 = make_int4(0, 0, 0, 0); + float2 uv0 = indexTextureLinear(p, uv, tz, tc0, level0); + const float* pIn0 = p.tex[level0]; + float* pOut0 = p.gradTex[level0]; + bool corner0 = CUBE_MODE && ((tc0.x | tc0.y | tc0.z | tc0.w) < 0); + tc0 *= p.channels; + + // Texel weights. + float uv011 = uv0.x * uv0.y; + float uv010 = uv0.x - uv011; + float uv001 = uv0.y - uv011; + float uv000 = 1.f - uv0.x - uv001; + float4 tw0 = make_float4(uv000, uv010, uv001, uv011); + + // Attribute weights. + int2 sz0 = mipLevelSize(p, level0); + float sclu0 = (float)sz0.x; + float sclv0 = (float)sz0.y; + + // Bilinear mode - texture and uv gradients. + if (FILTER_MODE == TEX_MODE_LINEAR || FILTER_MODE == TEX_MODE_LINEAR_MIPMAP_NEAREST) + { + for (int i=0; i < p.channels; i++, tc0 += 1) + { + float dy = pDy[i]; + accumQuad(tw0 * dy, pOut0, level0, tc0, corner0, CA_TEMP); + + float a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + float ad = (a11 + a00 - a10 - a01); + gu += dy * ((a10 - a00) + uv0.y * ad) * sclu0; + gv += dy * ((a01 - a00) + uv0.x * ad) * sclv0; + } + + // Store UV gradients and exit. + if (CUBE_MODE) + ((float3*)p.gradUV)[pidx] = indexCubeMapGrad(uv, gu, gv); + else + ((float2*)p.gradUV)[pidx] = make_float2(gu, gv); + + return; + } + + // Accumulate fractional mip level gradient. + float df = 0; // dL/df. + + // Get texel indices and pointers for level 1. + int4 tc1 = make_int4(0, 0, 0, 0); + float2 uv1 = indexTextureLinear(p, uv, tz, tc1, level1); + const float* pIn1 = p.tex[level1]; + float* pOut1 = p.gradTex[level1]; + bool corner1 = CUBE_MODE && ((tc1.x | tc1.y | tc1.z | tc1.w) < 0); + tc1 *= p.channels; + + // Texel weights. + float uv111 = uv1.x * uv1.y; + float uv110 = uv1.x - uv111; + float uv101 = uv1.y - uv111; + float uv100 = 1.f - uv1.x - uv101; + float4 tw1 = make_float4(uv100, uv110, uv101, uv111); + + // Attribute weights. + int2 sz1 = mipLevelSize(p, level1); + float sclu1 = (float)sz1.x; + float sclv1 = (float)sz1.y; + + // Trilinear mode. + for (int i=0; i < p.channels; i++, tc0 += 1, tc1 += 1) + { + float dy = pDy[i]; + float dy0 = (1.f - flevel) * dy; + accumQuad(tw0 * dy0, pOut0, level0, tc0, corner0, CA_TEMP); + + // UV gradients for first level. + float a00, a10, a01, a11; + fetchQuad(a00, a10, a01, a11, pIn0, tc0, corner0); + float ad = (a11 + a00 - a10 - a01); + gu += dy0 * ((a10 - a00) + uv0.y * ad) * sclu0; + gv += dy0 * ((a01 - a00) + uv0.x * ad) * sclv0; + + // Second level unless in magnification mode. + if (flevel > 0.f) + { + // Texture gradients for second level. + float dy1 = flevel * dy; + accumQuad(tw1 * dy1, pOut1, level1, tc1, corner1, CA_TEMP); + + // UV gradients for second level. + float b00, b10, b01, b11; + fetchQuad(b00, b10, b01, b11, pIn1, tc1, corner1); + float bd = (b11 + b00 - b10 - b01); + gu += dy1 * ((b10 - b00) + uv1.y * bd) * sclu1; + gv += dy1 * ((b01 - b00) + uv1.x * bd) * sclv1; + + // Mip level gradient. + float a = bilerp(a00, a10, a01, a11, uv0); + float b = bilerp(b00, b10, b01, b11, uv1); + df += (b-a) * dy; + } + } + + // Store UV gradients. + if (CUBE_MODE) + ((float3*)p.gradUV)[pidx] = indexCubeMapGrad(uv, gu, gv) + (dfdv * df); + else + ((float2*)p.gradUV)[pidx] = make_float2(gu, gv); + + // Store mip level bias gradient. + if (p.gradMipLevelBias) + p.gradMipLevelBias[pidx] = df; + + // Store UV pixel differential gradients. + if (!BIAS_ONLY) + { + // Final gradients. + dw *= df; // dL/(d{s,y}/d{X,Y}) = df/(d{s,y}/d{X,Y}) * dL/df. + + // Store them. + if (CUBE_MODE) + { + // Remap from dL/(d{s,t}/s{X,Y}) to dL/(d{x,y,z}/d{X,Y}). + float3 g0, g1; + indexCubeMapGrad4(uv, dw, g0, g1); + ((float2*)p.gradUVDA)[3 * pidx + 0] = make_float2(g0.x, g1.x); + ((float2*)p.gradUVDA)[3 * pidx + 1] = make_float2(g0.y, g1.y); + ((float2*)p.gradUVDA)[3 * pidx + 2] = make_float2(g0.z, g1.z); + } + else + ((float4*)p.gradUVDA)[pidx] = dw; + } +} + +// Template specializations. +__global__ void TextureGradKernelNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapNearest (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapLinear (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapNearestBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelLinearMipmapLinearBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapNearestBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } +__global__ void TextureGradKernelCubeLinearMipmapLinearBO (const TextureKernelParams p) { TextureGradKernelTemplate(p); } + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/lib/setgpu.lib b/extensions/nvdiffrast/nvdiffrast/lib/setgpu.lib new file mode 100644 index 0000000000000000000000000000000000000000..add9a0c4f631cb56dbee31a05ed97339930301e2 Binary files /dev/null and b/extensions/nvdiffrast/nvdiffrast/lib/setgpu.lib differ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/__init__.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2ac67a712d7df7ecc40cc2d1dfd67e5f738be9f --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .ops import rasterize, interpolate, texture, antialias +from .plugin_loader import set_cache_dir + +__all__ = ["rasterize", "interpolate", "texture", "antialias", "set_cache_dir"] diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..130e2e195e2dce3e007aaebc55cc03837cf77c02 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/ops.py @@ -0,0 +1,303 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import tensorflow as tf +import numpy as np +import os +from . import plugin_loader + +#---------------------------------------------------------------------------- +# Helpers. +#---------------------------------------------------------------------------- + +# OpenGL-related linker options depending on platform. +def _get_gl_opts(): + libs = { + 'posix': ['GL', 'EGL'], + 'nt': ['gdi32', 'opengl32', 'user32', 'setgpu'], + } + return ['-l' + x for x in libs[os.name]] + +# Load the cpp plugin. +def _get_plugin(): + fn = os.path.join(os.path.dirname(__file__), 'tf_all.cu') + return plugin_loader.get_plugin(fn, extra_nvcc_options=_get_gl_opts() + ['-DNVDR_TENSORFLOW']) + +# Convert parameter to a numpy array if possible. +def _get_constant(x, dtype): + try: + return np.asarray(x, dtype=dtype) + except (TypeError, ValueError): + return None + +# Tests for a construction-time constantness instead of tf.constant node because +# the latter can be overridden in Session.run() feed_dict at evaluation time. +def _is_constant(x, dtype): + if isinstance(x, np.ndarray): + return np.can_cast(x.dtype, dtype, 'unsafe') + else: + return _get_constant(x, dtype) is not None + +#---------------------------------------------------------------------------- +# Rasterize. +#---------------------------------------------------------------------------- + +def rasterize(pos, tri, resolution, ranges=None, tri_const=False, output_db=True, grad_db=True): + assert tri_const is True or tri_const is False + assert output_db is True or output_db is False + + # Known constant resolution? + resolution_c = _get_constant(resolution, np.int32) + + # Known constant triangles? + tri_const = tri_const or _is_constant(tri, np.int32) + + # Convert all inputs to tensors / base types. + tri_const = 1 if tri_const else 0 + tri = tf.convert_to_tensor(tri, dtype=tf.int32) + pos = tf.convert_to_tensor(pos, dtype=tf.float32) + resolution = tf.convert_to_tensor(resolution, dtype=tf.int32) + if ranges is None: + ranges = tf.convert_to_tensor(np.zeros(shape=[0, 2], dtype=np.int32)) # Empty tensor. + else: + ranges = tf.convert_to_tensor(ranges, dtype=tf.int32) # Convert input to tensor. + + # Infer as much about the output shape as possible. + out_shape = [None, None, None, 4] + if pos.shape.rank == 3: # Instanced mode. + out_shape[0] = pos.shape[0].value + elif pos.shape.rank == 2: # Range mode. + if ranges.shape.rank not in [None, 0]: + out_shape[0] = ranges.shape[0].value + if resolution_c is not None: + assert resolution_c.shape == (2,) + out_shape[1], out_shape[2] = resolution_c + + # Output pixel differentials. + @tf.custom_gradient + def func_db(pos): + out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 1, tri_const) + out.set_shape(out_shape) + out_db.set_shape(out_shape) + def grad(dy, ddb): + if grad_db: + return _get_plugin().rasterize_grad_db(pos, tri, out, dy, ddb) + else: + return _get_plugin().rasterize_grad(pos, tri, out, dy) + return (out, out_db), grad + + # Do not output pixel differentials. + @tf.custom_gradient + def func(pos): + out, out_db = _get_plugin().rasterize_fwd(pos, tri, resolution, ranges, 0, tri_const) + out.set_shape(out_shape) + out_db.set_shape(out_shape[:-1] + [0]) # Zero channels in out_db. + def grad(dy, _): + return _get_plugin().rasterize_grad(pos, tri, out, dy) + return (out, out_db), grad + + # Choose stub. + if output_db: + return func_db(pos) + else: + return func(pos) + +#---------------------------------------------------------------------------- +# Interpolate. +#---------------------------------------------------------------------------- + +def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): + # Sanitize the list of pixel differential attributes. + if diff_attrs is None: + diff_attrs = [] + elif diff_attrs != 'all': + diff_attrs = _get_constant(diff_attrs, np.int32) + assert (diff_attrs is not None) and len(diff_attrs.shape) == 1 + diff_attrs = diff_attrs.tolist() + + # Convert all inputs to tensors. + attr = tf.convert_to_tensor(attr, dtype=tf.float32) + rast = tf.convert_to_tensor(rast, dtype=tf.float32) + tri = tf.convert_to_tensor(tri, dtype=tf.int32) + if diff_attrs: + rast_db = tf.convert_to_tensor(rast_db, dtype=tf.float32) + + # Infer output shape. + out_shape = [None, None, None, None] + if rast.shape.rank is not None: + out_shape = [rast.shape[0].value, rast.shape[1].value, rast.shape[2].value, None] + if attr.shape.rank in [2, 3]: + out_shape[3] = attr.shape[-1].value + + # Output pixel differentials for at least some attributes. + @tf.custom_gradient + def func_da(attr, rast, rast_db): + diff_attrs_all = int(diff_attrs == 'all') + diff_attrs_list = [] if diff_attrs_all else diff_attrs + out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + + # Infer number of channels in out_da. + if not diff_attrs_all: + da_channels = 2 * len(diff_attrs) + if (attr.shape.rank in [2, 3]) and (attr.shape[-1].value is not None): + da_channels = 2 * attr.shape[-1].value + else: + da_channels = None + + # Set output shapes. + out.set_shape(out_shape) + out_da.set_shape([out_shape[0], out_shape[1], out_shape[2], da_channels]) + + def grad(dy, dda): + return _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) + return (out, out_da), grad + + # No pixel differentials for any attribute. + @tf.custom_gradient + def func(attr, rast): + out, out_da = _get_plugin().interpolate_fwd(attr, rast, tri) + out.set_shape(out_shape) + out_da.set_shape(out_shape[:-1] + [0]) # Zero channels in out_da. + def grad(dy, _): + return _get_plugin().interpolate_grad(attr, rast, tri, dy) + return (out, out_da), grad + + # Choose stub. + if diff_attrs: + return func_da(attr, rast, rast_db) + else: + return func(attr, rast) + +#---------------------------------------------------------------------------- +# Texture. +#---------------------------------------------------------------------------- + +def texture(tex, uv, uv_da=None, filter_mode='auto', boundary_mode='wrap', tex_const=False, max_mip_level=None): + assert tex_const is True or tex_const is False + + # Default filter mode. + if filter_mode == 'auto': + filter_mode = 'linear-mipmap-linear' if (uv_da is not None) else 'linear' + + # Known constant texture? + tex_const = tex_const or _is_constant(tex, np.float32) + + # Sanitize inputs. + tex_const = 1 if tex_const else 0 + if max_mip_level is None: + max_mip_level = -1 + else: + max_mip_level = int(max_mip_level) + assert max_mip_level >= 0 + + # Convert inputs to tensors. + tex = tf.convert_to_tensor(tex, dtype=tf.float32) + uv = tf.convert_to_tensor(uv, dtype=tf.float32) + if 'mipmap' in filter_mode: + uv_da = tf.convert_to_tensor(uv_da, dtype=tf.float32) + + # Infer output shape. + out_shape = [None, None, None, None] + if uv.shape.rank is not None: + assert uv.shape.rank == 4 + out_shape = [uv.shape[0].value, uv.shape[1].value, uv.shape[2].value, None] + if tex.shape.rank is not None: + assert tex.shape.rank == (5 if boundary_mode == 'cube' else 4) + out_shape[-1] = tex.shape[-1].value + + # If mipping disabled via max level=0, we may as well use simpler filtering internally. + if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']: + filter_mode = 'linear' + + # Convert filter mode to internal enumeration. + filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3} + filter_mode_enum = filter_mode_dict[filter_mode] + + # Convert boundary mode to internal enumeration. + boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3} + boundary_mode_enum = boundary_mode_dict[boundary_mode] + + # Linear-mipmap-linear: Mipmaps enabled, all gradients active. + @tf.custom_gradient + def func_linear_mipmap_linear(tex, uv, uv_da): + out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level) + return out, grad + + # Linear-mipmap-nearest: Mipmaps enabled, no gradients to uv_da. + @tf.custom_gradient + def func_linear_mipmap_nearest(tex, uv): + out, mip = _get_plugin().texture_fwd_mip(tex, uv, uv_da, filter_mode_enum, boundary_mode_enum, tex_const, max_mip_level) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip, filter_mode_enum, boundary_mode_enum, max_mip_level) + return out, grad + + # Linear: Mipmaps disabled, no uv_da, no gradients to uv_da. + @tf.custom_gradient + def func_linear(tex, uv): + out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return out, grad + + # Nearest: Mipmaps disabled, no uv_da, no gradients to uv_da or uv. + @tf.custom_gradient + def func_nearest(tex): + out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) + out.set_shape(out_shape) + def grad(dy): + return _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return out, grad + + # Choose stub. + if filter_mode == 'linear-mipmap-linear': + return func_linear_mipmap_linear(tex, uv, uv_da) + elif filter_mode == 'linear-mipmap-nearest': + return func_linear_mipmap_nearest(tex, uv) + elif filter_mode == 'linear': + return func_linear(tex, uv) + elif filter_mode == 'nearest': + return func_nearest(tex) + +#---------------------------------------------------------------------------- +# Antialias. +#---------------------------------------------------------------------------- + +def antialias(color, rast, pos, tri, tri_const=False, pos_gradient_boost=1.0): + assert tri_const is True or tri_const is False + + # Known constant triangles? + tri_const = tri_const or _is_constant(tri, np.int32) + + # Convert inputs to tensors. + color = tf.convert_to_tensor(color, dtype=tf.float32) + rast = tf.convert_to_tensor(rast, dtype=tf.float32) + pos = tf.convert_to_tensor(pos, dtype=tf.float32) + tri = tf.convert_to_tensor(tri, dtype=tf.int32) + + # Sanitize inputs. + tri_const = 1 if tri_const else 0 + + @tf.custom_gradient + def func(color, pos): + color_out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, tri_const) + color_out.set_shape(color.shape) + def grad(dy): + grad_color, grad_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer) + if pos_gradient_boost != 1.0: + grad_pos = grad_pos * pos_gradient_boost + return grad_color, grad_pos + return color_out, grad + + return func(color, pos) + +#---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py b/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..1347df8a5a276298dbeff81e7b06f62e6765a3ef --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/plugin_loader.py @@ -0,0 +1,219 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import glob +import os +import re +import uuid +import hashlib +import tempfile +import shutil +import tensorflow as tf +from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module + +#---------------------------------------------------------------------------- +# Global options. + +_nvdiffrast_cache_dir = None + +def set_cache_dir(path: str) -> None: + '''Set CUDA kernel compilation temp dir. + + If `set_cache_dir` is not called, the cache directory will default to + one of the below: + + - Value of NVDIFFRAST_CACHE_DIR env var, if set + - $HOME/.cache/nvdiffrast if HOME env var is set + - $USERPROFILE/.cache/nvdiffrast if USERPROFILE is set. + + Args: + path: Where to save CUDA kernel build temporaries + ''' + global _nvdiffrast_cache_dir + _nvdiffrast_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _nvdiffrast_cache_dir is not None: + return os.path.join(_nvdiffrast_cache_dir, *paths) + if 'NVDIFFRAST_CACHE_DIR' in os.environ: + return os.path.join(os.environ['NVDIFFRAST_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'nvdiffrast', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'nvdiffrast', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'nvdiffrast', *paths) + +cuda_cache_version_tag = 'v1' +do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! +verbose = True # Print status messages to stdout. + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin' + if os.path.isdir(vc_bin_dir): + return vc_bin_dir + return None + +def _get_compute_cap(device): + caps_str = device.physical_device_desc + m = re.search('compute capability: (\\d+).(\\d+)', caps_str) + major = m.group(1) + minor = m.group(2) + return (major, minor) + +def _get_cuda_gpu_arch_string(): + gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] + if len(gpus) == 0: + raise RuntimeError('No GPU devices found') + (major, minor) = _get_compute_cap(gpus[0]) + return 'sm_%s%s' % (major, minor) + +def _run_cmd(cmd): + with os.popen(cmd) as pipe: + output = pipe.read() + status = pipe.close() + if status is not None: + raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) + +def _prepare_nvcc_cli(opts): + cmd = 'nvcc ' + opts.strip() + cmd += ' --disable-warnings' + cmd += ' --include-path "%s"' % tf.sysconfig.get_include() + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') + + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + # Require that _find_compiler_bindir succeeds on Windows. Allow + # nvcc to use whatever is the default on Linux. + if os.name == 'nt': + raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) + else: + cmd += ' --compiler-bindir "%s"' % compiler_bindir + cmd += ' 2>&1' + return cmd + +#---------------------------------------------------------------------------- +# Main entry point. + +_plugin_cache = dict() + +def get_plugin(cuda_file, extra_nvcc_options=[]): + cuda_file_base = os.path.basename(cuda_file) + cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) + + # Already in cache? + if cuda_file in _plugin_cache: + return _plugin_cache[cuda_file] + + # Setup plugin. + if verbose: + print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) + try: + # Hash CUDA source. + md5 = hashlib.md5() + with open(cuda_file, 'rb') as f: + md5.update(f.read()) + md5.update(b'\n') + + # Hash headers included by the CUDA code by running it through the preprocessor. + if not do_not_hash_included_headers: + if verbose: + print('Preprocessing... ', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) + _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) + with open(tmp_file, 'rb') as f: + bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros + good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') + for ln in f: + if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas + ln = ln.replace(bad_file_str, good_file_str) + md5.update(ln) + md5.update(b'\n') + + # Select compiler options. + compile_opts = '' + if os.name == 'nt': + compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') + compile_opts += ' --library-path="%s"' % (os.path.dirname(__file__) + r"\..\lib") # Find libraries during compilation. + elif os.name == 'posix': + compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so') + compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\'' + else: + assert False # not Windows or Linux, w00t? + compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string() + compile_opts += ' --use_fast_math' + for opt in extra_nvcc_options: + compile_opts += ' ' + opt + nvcc_cmd = _prepare_nvcc_cli(compile_opts) + + # Hash build configuration. + md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') + md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') + md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') + + # Compile if not already compiled. + bin_file_ext = '.dll' if os.name == 'nt' else '.so' + cuda_cache_path = make_cache_dir_path() + bin_file = os.path.join(make_cache_dir_path(), cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) + if not os.path.isfile(bin_file): + if verbose: + print('Compiling... ', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) + _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) + os.makedirs(cuda_cache_path, exist_ok=True) + intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) + shutil.copyfile(tmp_file, intermediate_file) + os.rename(intermediate_file, bin_file) # atomic + + # Load. + if verbose: + print('Loading... ', end='', flush=True) + plugin = tf.load_op_library(bin_file) + + # Add to cache. + _plugin_cache[cuda_file] = plugin + if verbose: + print('Done.', flush=True) + return plugin + + except: + if verbose: + print('Failed!', flush=True) + raise + +#---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_all.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_all.cu new file mode 100644 index 0000000000000000000000000000000000000000..21b32399c8a2429faa89e3356a34837d0d7e3138 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_all.cu @@ -0,0 +1,36 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +// TF-specific helpers. + +#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal("Cuda error: ", cudaGetErrorName(err), "[", #CUDA_CALL, ";]")); } while (0) +#define OP_CHECK_GL_ERROR(CTX, GL_CALL) do { GL_CALL; GLenum err = glGetError(); OP_REQUIRES(CTX, err == GL_NO_ERROR, errors::Internal("OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]")); } while (0) + +// Cuda kernels and CPP all together. What an absolute compilation unit. + +#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ +#include "../common/framework.h" +#include "../common/glutil.cpp" + +#include "../common/common.h" +#include "../common/common.cpp" + +#include "../common/rasterize.h" +#include "../common/rasterize_gl.cpp" +#include "../common/rasterize.cu" +#include "tf_rasterize.cu" + +#include "../common/interpolate.cu" +#include "tf_interpolate.cu" + +#include "../common/texture.cpp" +#include "../common/texture.cu" +#include "tf_texture.cu" + +#include "../common/antialias.cu" +#include "tf_antialias.cu" diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu new file mode 100644 index 0000000000000000000000000000000000000000..a8f1b5930a1c4145bf63f925168e01728c36a28c --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_antialias.cu @@ -0,0 +1,278 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +struct AntialiasFwdOp : public OpKernel +{ + AntialiasKernelParams m_attribs; + + AntialiasFwdOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tri_const", &m_attribs.tri_const)); + } + + void Compute(OpKernelContext* ctx) + { + AntialiasKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. + const Tensor& color = ctx->input(0); + const Tensor& rasterOut = ctx->input(1); + const Tensor& pos = ctx->input(2); + const Tensor& tri = ctx->input(3); + + // Instance rendering mode? + p.instance_mode = pos.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + p.numVertices = (pos.dims() > 1) ? pos.dim_size(1) : 0; + else + p.numVertices = (pos.dims() > 0) ? pos.dim_size(0) : 0; + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.n = (color.dims() > 0) ? color.dim_size(0) : 0; + p.height = (color.dims() > 1) ? color.dim_size(1) : 0; + p.width = (color.dims() > 2) ? color.dim_size(2) : 0; + p.channels = (color.dims() > 3) ? color.dim_size(3) : 0; + + // Sanity checks. + OP_REQUIRES(ctx, color.dims() == 4 && color.dim_size(0) > 0 && color.dim_size(1) > 0 && color.dim_size(2) > 0 && color.dim_size(3) > 0, errors::InvalidArgument("color must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, rasterOut.dims() == 4 && rasterOut.dim_size(0) > 0 && rasterOut.dim_size(1) > 0 && rasterOut.dim_size(2) > 0 && rasterOut.dim_size(3) == 4, errors::InvalidArgument("raster_out must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, color.dim_size(1) == rasterOut.dim_size(1) && color.dim_size(2) == rasterOut.dim_size(2), errors::InvalidArgument("color and raster_out inputs must have same spatial dimensions")); + if (p.instance_mode) + { + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) > 0 && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n && pos.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out, pos")); + } + else + { + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out")); + } + + // Get input pointers. + p.color = color.flat().data(); + p.rasterOut = rasterOut.flat().data(); + p.tri = tri.flat().data(); + p.pos = pos.flat().data(); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Allocate output tensor. + Tensor* outputTensor = NULL; + TensorShape outputShape; + outputShape.AddDim(p.n); + outputShape.AddDim(p.height); + outputShape.AddDim(p.width); + outputShape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, outputShape, &outputTensor)); + p.output = outputTensor->flat().data(); + + // Allocate work buffer. One extra int4 for storing counters. + Tensor* workTensor = NULL; + TensorShape workShape; + workShape.AddDim(p.n * p.width * p.height * 8 + 4); // 8 int for a maximum of two work items per pixel. + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, workShape, &workTensor)); + p.workBuffer = (int4*)(workTensor->flat().data()); + + // Clear the work counters. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.workBuffer, 0, sizeof(int4), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.pos & 15), errors::Internal("pos input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.rasterOut & 7), errors::Internal("raster_out input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.workBuffer & 15), errors::Internal("work_buffer internal tensor not aligned to int4")); + + // Kernel parameters. + void* args[] = {&p}; + + // (Re-)calculate opposite vertex hash. + if (!p.evHash || !p.tri_const) + { + if (p.allocTriangles < p.numTriangles) + { + p.allocTriangles = max(p.allocTriangles, 64); + while (p.allocTriangles < p.numTriangles) + p.allocTriangles <<= 1; // Must be power of two. + + // (Re-)allocate memory for the hash. + OP_CHECK_CUDA_ERROR(ctx, cudaFree(p.evHash)); + OP_CHECK_CUDA_ERROR(ctx, cudaMalloc(&p.evHash, p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * sizeof(uint4))); + LOG(INFO) << "Increasing topology hash size to accommodate " << p.allocTriangles << " triangles"; + } + + // Clear the hash and launch the mesh kernel to populate it. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.evHash, 0, p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * sizeof(uint4), stream)); + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasFwdMeshKernel, (p.numTriangles - 1) / AA_MESH_KERNEL_THREADS_PER_BLOCK + 1, AA_MESH_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + } + + // Copy input to output as a baseline. + OP_CHECK_CUDA_ERROR(ctx, cudaMemcpyAsync(p.output, p.color, p.n * p.height * p.width * p.channels * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + + // Choose launch parameters for the discontinuity finder kernel and launch. + dim3 blockSize(AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH, AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT, 1); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.n); + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasFwdDiscontinuityKernel, gridSize, blockSize, args, 0, stream)); + + // Determine optimum block size for the persistent analysis kernel. + int device = 0; + int numCTA = 0; + int numSM = 0; + OP_CHECK_CUDA_ERROR(ctx, cudaGetDevice(&device)); + OP_CHECK_CUDA_ERROR(ctx, cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasFwdAnalysisKernel, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, 0)); + OP_CHECK_CUDA_ERROR(ctx, cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device)); + + // Launch analysis kernel. + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasFwdAnalysisKernel, numCTA * numSM, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + } +}; + +REGISTER_OP("AntialiasFwd") + .Input ("color: float") + .Input ("raster_out: float") + .Input ("pos: float") + .Input ("tri: int32") + .Output ("output: float") + .Output ("work_buffer: int32") + .Attr ("tri_const: int"); + +REGISTER_KERNEL_BUILDER(Name("AntialiasFwd").Device(DEVICE_GPU), AntialiasFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +struct AntialiasGradOp : public OpKernel +{ + AntialiasKernelParams m_attribs; + + AntialiasGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + } + + void Compute(OpKernelContext* ctx) + { + AntialiasKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. + const Tensor& color = ctx->input(0); + const Tensor& rasterOut = ctx->input(1); + const Tensor& pos = ctx->input(2); + const Tensor& tri = ctx->input(3); + const Tensor& dy = ctx->input(4); + const Tensor& workBuffer = ctx->input(5); + + // Instance rendering mode? + p.instance_mode = pos.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + p.numVertices = (pos.dims() > 1) ? pos.dim_size(1) : 0; + else + p.numVertices = (pos.dims() > 0) ? pos.dim_size(0) : 0; + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.n = (color.dims() > 0) ? color.dim_size(0) : 0; + p.height = (color.dims() > 1) ? color.dim_size(1) : 0; + p.width = (color.dims() > 2) ? color.dim_size(2) : 0; + p.channels = (color.dims() > 3) ? color.dim_size(3) : 0; + + // Sanity checks. + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) > 0 && dy.dim_size(1) > 0 && dy.dim_size(2) > 0 && dy.dim_size(3) > 0, errors::InvalidArgument("dy must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, color.dims() == 4 && color.dim_size(0) > 0 && color.dim_size(1) > 0 && color.dim_size(2) > 0 && color.dim_size(3) > 0, errors::InvalidArgument("color must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, rasterOut.dims() == 4 && rasterOut.dim_size(0) > 0 && rasterOut.dim_size(1) > 0 && rasterOut.dim_size(2) > 0 && rasterOut.dim_size(3) == 4, errors::InvalidArgument("raster_out must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, color.dim_size(1) == rasterOut.dim_size(1) && color.dim_size(2) == rasterOut.dim_size(2), errors::InvalidArgument("color and raster_out inputs must have same spatial dimensions")); + OP_REQUIRES(ctx, color.dim_size(1) == dy.dim_size(1) && color.dim_size(2) == dy.dim_size(2) && color.dim_size(3) == dy.dim_size(3), errors::InvalidArgument("color and dy inputs must have same dimensions")); + if (p.instance_mode) + { + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) > 0 && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n && pos.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out, pos")); + OP_REQUIRES(ctx, dy.dim_size(0) == p.n && rasterOut.dim_size(0) == p.n && pos.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs dy, color, raster_out, pos")); + } + else + { + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("pos must have shape [>0, >0, 4] or [>0, 4]")); + OP_REQUIRES(ctx, rasterOut.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs color, raster_out")); + OP_REQUIRES(ctx, dy.dim_size(0) == p.n && rasterOut.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs dy, color, raster_out")); + } + + // Get input pointers. + p.dy = dy.flat().data(); + p.color = color.flat().data(); + p.rasterOut = rasterOut.flat().data(); + p.tri = tri.flat().data(); + p.pos = pos.flat().data(); + p.workBuffer = (int4*)(workBuffer.flat().data()); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Allocate color gradient output tensor. + Tensor* gradColor = NULL; + TensorShape gradColorShape; + gradColorShape.AddDim(p.n); + gradColorShape.AddDim(p.height); + gradColorShape.AddDim(p.width); + gradColorShape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, gradColorShape, &gradColor)); + p.gradColor = gradColor->flat().data(); + + // Allocate position gradient output tensor. + Tensor* gradPos = NULL; + TensorShape gradPosShape; + if (p.instance_mode) + gradPosShape.AddDim(p.n); + gradPosShape.AddDim(p.numVertices); + gradPosShape.AddDim(4); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, gradPosShape, &gradPos)); + p.gradPos = gradPos->flat().data(); + + // Initialize all the stuff. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(&p.workBuffer[0].y, 0, sizeof(int), stream)); // Gradient kernel work counter. + OP_CHECK_CUDA_ERROR(ctx, cudaMemcpyAsync(p.gradColor, p.dy, p.n * p.height * p.width * p.channels * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.gradPos, 0, (p.instance_mode ? p.n : 1) * p.numVertices * 4 * sizeof(float), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.pos & 15), errors::Internal("pos input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.workBuffer & 15), errors::Internal("work_buffer internal tensor not aligned to int4")); + + // Launch the gradient kernel. + void* args[] = {&p}; + + int device = 0; + int numCTA = 0; + int numSM = 0; + OP_CHECK_CUDA_ERROR(ctx, cudaGetDevice(&device)); + OP_CHECK_CUDA_ERROR(ctx, cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasGradKernel, AA_GRAD_KERNEL_THREADS_PER_BLOCK, 0)); + OP_CHECK_CUDA_ERROR(ctx, cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device)); + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)AntialiasGradKernel, numCTA * numSM, AA_GRAD_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + } +}; + +REGISTER_OP("AntialiasGrad") + .Input ("color: float") + .Input ("raster_out: float") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("dy: float") + .Input ("work_buffer: int32") + .Output ("grad_color: float") + .Output ("grad_pos: float"); + +REGISTER_KERNEL_BUILDER(Name("AntialiasGrad").Device(DEVICE_GPU), AntialiasGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu new file mode 100644 index 0000000000000000000000000000000000000000..6b44165cab026937014362893751e5c26be3fac3 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_interpolate.cu @@ -0,0 +1,301 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Common op attribute parser. + +static __host__ void interpolateParseOpAttributes(OpKernelConstruction* ctx, InterpolateKernelParams& p, bool enableDA) +{ + if (enableDA) + { + OP_REQUIRES_OK(ctx, ctx->GetAttr("diff_attrs_all", &p.diff_attrs_all)); + if (!p.diff_attrs_all) + { + std::vector diff_attrs_vec; + OP_REQUIRES_OK(ctx, ctx->GetAttr("diff_attrs", &diff_attrs_vec)); + OP_REQUIRES(ctx, diff_attrs_vec.size() > 0, errors::InvalidArgument("differentiation enabled with empty diff_attrs list")); + OP_REQUIRES(ctx, diff_attrs_vec.size() <= IP_MAX_DIFF_ATTRS, errors::InvalidArgument("too many entries in diff_attrs list (increase IP_MAX_DIFF_ATTRS)")); + p.numDiffAttr = diff_attrs_vec.size(); + memcpy(p.diffAttrs, &diff_attrs_vec[0], diff_attrs_vec.size()*sizeof(int)); + } + } +} + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +template +struct InterpolateFwdOp : public OpKernel +{ + InterpolateKernelParams m_attribs; + + InterpolateFwdOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA); + } + + void Compute(OpKernelContext* ctx) + { + InterpolateKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. + const Tensor& attr = ctx->input(0); + const Tensor& rast = ctx->input(1); + const Tensor& tri = ctx->input(2); + const Tensor& rast_db = ctx->input(ENABLE_DA ? 3 : 2); + + // Instance rendering mode? + p.instance_mode = attr.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + { + p.numVertices = (attr.dims() > 1) ? attr.dim_size(1) : 0; + p.numAttr = (attr.dims() > 2) ? attr.dim_size(2) : 0; + } + else + { + p.numVertices = (attr.dims() > 0) ? attr.dim_size(0) : 0; + p.numAttr = (attr.dims() > 1) ? attr.dim_size(1) : 0; + } + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.height = (rast.dims() > 1) ? rast.dim_size(1) : 0; + p.width = (rast.dims() > 2) ? rast.dim_size(2) : 0; + p.depth = (rast.dims() > 0) ? rast.dim_size(0) : 0; + + // Sanity checks. + OP_REQUIRES(ctx, rast.dims() == 4 && rast.dim_size(0) > 0 && rast.dim_size(1) > 0 && rast.dim_size(2) > 0 && rast.dim_size(3) == 4, errors::InvalidArgument("rast must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, (attr.dims() == 2 || attr.dims() == 3) && attr.dim_size(0) > 0 && attr.dim_size(1) > 0 && (attr.dims() == 2 || attr.dim_size(2) > 0), errors::InvalidArgument("attr must have shape [>0, >0, >0] or [>0, >0]")); + if (p.instance_mode) + OP_REQUIRES(ctx, attr.dim_size(0) == p.depth || attr.dim_size(0) == 1, errors::InvalidArgument("minibatch size mismatch between inputs rast, attr")); + if (ENABLE_DA) + { + OP_REQUIRES(ctx, rast_db.dims() == 4 && rast_db.dim_size(0) > 0 && rast_db.dim_size(1) > 0 && rast_db.dim_size(2) > 0 && rast_db.dim_size(3) == 4, errors::InvalidArgument("rast_db must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, rast_db.dim_size(1) == rast.dim_size(1) && rast_db.dim_size(2) == rast.dim_size(2), errors::InvalidArgument("spatial size mismatch between inputs rast and rast_db")); + OP_REQUIRES(ctx, rast_db.dim_size(0) == p.depth, errors::InvalidArgument("minibatch size mismatch between inputs rast, rast_db")); + } + + // All diff attrs mode. + if (p.diff_attrs_all) + p.numDiffAttr = p.numAttr; + + // Get input pointers. + p.attr = attr.flat().data(); + p.rast = rast.flat().data(); + p.tri = tri.flat().data(); + p.attrBC = (p.instance_mode && attr.dim_size(0) == 1) ? 1 : 0; + p.rastDB = ENABLE_DA ? rast_db.flat().data() : 0; + + // Allocate main output tensor. + Tensor* out_tensor = NULL; + TensorShape out_shape; + out_shape.AddDim(p.depth); + out_shape.AddDim(p.height); + out_shape.AddDim(p.width); + out_shape.AddDim(p.numAttr); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out_tensor)); + p.out = out_tensor->flat().data(); + + // Allocate pixel differential output tensor. + Tensor* out_da_tensor = NULL; + out_shape.set_dim(3, p.numDiffAttr * 2); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, out_shape, &out_da_tensor)); + p.outDA = ENABLE_DA ? out_da_tensor->flat().data() : 0; + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.rast & 15), errors::Internal("rast input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); + if (ENABLE_DA) + OP_REQUIRES(ctx, !((uintptr_t)p.outDA & 7), errors::Internal("out_da output tensor not aligned to float2")); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_FWD_MAX_KERNEL_BLOCK_WIDTH, IP_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = ENABLE_DA ? (void*)InterpolateFwdKernelDa : (void*)InterpolateFwdKernel; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("InterpolateFwd") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Output ("out: float") + .Output ("out_da: float"); + +REGISTER_OP("InterpolateFwdDa") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Input ("rast_db: float") + .Output ("out: float") + .Output ("out_da: float") + .Attr ("diff_attrs_all: int") + .Attr ("diff_attrs: list(int)"); + +REGISTER_KERNEL_BUILDER(Name("InterpolateFwd") .Device(DEVICE_GPU), InterpolateFwdOp); +REGISTER_KERNEL_BUILDER(Name("InterpolateFwdDa").Device(DEVICE_GPU), InterpolateFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +template +struct InterpolateGradOp : public OpKernel +{ + InterpolateKernelParams m_attribs; + + InterpolateGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + interpolateParseOpAttributes(ctx, m_attribs, ENABLE_DA); + } + + void Compute(OpKernelContext* ctx) + { + InterpolateKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Get input. + const Tensor& attr = ctx->input(0); + const Tensor& rast = ctx->input(1); + const Tensor& tri = ctx->input(2); + const Tensor& dy = ctx->input(3); + const Tensor& rast_db = ctx->input(ENABLE_DA ? 4 : 3); + const Tensor& dda = ctx->input(ENABLE_DA ? 5 : 3); + + // Instance rendering mode? + p.instance_mode = attr.dims() > 2; + + // Extract input dimensions. + if (p.instance_mode) + { + p.numVertices = (attr.dims() > 1) ? attr.dim_size(1) : 0; + p.numAttr = (attr.dims() > 2) ? attr.dim_size(2) : 0; + } + else + { + p.numVertices = (attr.dims() > 0) ? attr.dim_size(0) : 0; + p.numAttr = (attr.dims() > 1) ? attr.dim_size(1) : 0; + } + p.numTriangles = (tri.dims() > 0) ? tri.dim_size(0) : 0; + p.depth = (rast.dims() > 0) ? rast.dim_size(0) : 0; + p.height = (rast.dims() > 1) ? rast.dim_size(1) : 0; + p.width = (rast.dims() > 2) ? rast.dim_size(2) : 0; + int attr_depth = p.instance_mode ? (attr.dims() > 1 ? attr.dim_size(0) : 0) : 1; + + // Sanity checks. + OP_REQUIRES(ctx, rast.dims() == 4 && rast.dim_size(0) > 0 && rast.dim_size(1) > 0 && rast.dim_size(2) > 0 && rast.dim_size(3) == 4, errors::InvalidArgument("rast must have shape[>0, >0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, (attr.dims() == 2 || attr.dims() == 3) && attr.dim_size(0) > 0 && attr.dim_size(1) > 0 && (attr.dims() == 2 || attr.dim_size(2) > 0), errors::InvalidArgument("attr must have shape [>0, >0, >0] or [>0, >0]")); + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) > 0 && dy.dim_size(1) == p.height && dy.dim_size(2) == p.width && dy.dim_size(3) > 0, errors::InvalidArgument("dy must have shape [>0, height, width, >0]")); + OP_REQUIRES(ctx, dy.dim_size(3) == p.numAttr, errors::InvalidArgument("argument count mismatch between inputs dy, attr")); + OP_REQUIRES(ctx, (attr_depth == p.depth || attr_depth == 1) && dy.dim_size(0) == p.depth, errors::InvalidArgument("minibatch size mismatch between inputs rast, dy, attr")); + if (ENABLE_DA) + { + OP_REQUIRES(ctx, dda.dims() == 4 && dda.dim_size(0) > 0 && dda.dim_size(1) == p.height && dda.dim_size(2) == p.width, errors::InvalidArgument("dda must have shape [>0, height, width, ?]")); + OP_REQUIRES(ctx, dda.dim_size(0) == p.depth, errors::InvalidArgument("minibatch size mismatch between rast, dda")); + } + + // All diff attrs mode. + if (p.diff_attrs_all) + p.numDiffAttr = p.numAttr; + + // Get input pointers. + p.attr = attr.flat().data(); + p.rast = rast.flat().data(); + p.tri = tri.flat().data(); + p.dy = dy.flat().data(); + p.rastDB = ENABLE_DA ? rast_db.flat().data() : 0; + p.dda = ENABLE_DA ? dda.flat().data() : 0; + p.attrBC = (p.instance_mode && attr_depth < p.depth) ? 1 : 0; + + // Allocate attribute gradient output tensor. + Tensor* grad_attr_tensor = NULL; + TensorShape grad_attr_shape; + if (p.instance_mode) + grad_attr_shape.AddDim(attr_depth); + grad_attr_shape.AddDim(p.numVertices); + grad_attr_shape.AddDim(p.numAttr); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_attr_shape, &grad_attr_tensor)); + p.gradAttr = grad_attr_tensor->flat().data(); + + // Allocate bary gradient output tensor. + Tensor* grad_rast_tensor = NULL; + TensorShape grad_rast_shape; + grad_rast_shape.AddDim(p.depth); + grad_rast_shape.AddDim(p.height); + grad_rast_shape.AddDim(p.width); + grad_rast_shape.AddDim(4); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, grad_rast_shape, &grad_rast_tensor)); + p.gradRaster = grad_rast_tensor->flat().data(); + + // Allocate bary pixel diff gradient output tensor. + if (ENABLE_DA) + { + Tensor* grad_rast_db_tensor = NULL; + OP_REQUIRES_OK(ctx, ctx->allocate_output(2, grad_rast_shape, &grad_rast_db_tensor)); + p.gradRasterDB = grad_rast_db_tensor->flat().data(); + } + + // Clear attribute gradients. + cudaMemsetAsync(p.gradAttr, 0, attr_depth * p.numVertices * p.numAttr * sizeof(float), stream); + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.rast & 15), errors::Internal("rast input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradRaster & 15), errors::Internal("grad_rast output tensor not aligned to float4")); + if (ENABLE_DA) + { + OP_REQUIRES(ctx, !((uintptr_t)p.dda & 7), errors::Internal("dda input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.rastDB & 15), errors::Internal("rast_db input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradRasterDB & 15), errors::Internal("grad_rast_db output tensor not aligned to float4")); + } + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH, IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = ENABLE_DA ? (void*)InterpolateGradKernelDa : (void*)InterpolateGradKernel; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("InterpolateGrad") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Input ("dy: float") + .Output ("grad_attr: float") + .Output ("grad_rast: float") + ; + +REGISTER_OP("InterpolateGradDa") + .Input ("attr: float") + .Input ("rast: float") + .Input ("tri: int32") + .Input ("dy: float") + .Input ("rast_db: float") + .Input ("dda: float") + .Output ("grad_attr: float") + .Output ("grad_rast: float") + .Output ("grad_rast_db: float") + .Attr ("diff_attrs_all: int") + .Attr ("diff_attrs: list(int)"); + ; + +REGISTER_KERNEL_BUILDER(Name("InterpolateGrad") .Device(DEVICE_GPU), InterpolateGradOp); +REGISTER_KERNEL_BUILDER(Name("InterpolateGradDa").Device(DEVICE_GPU), InterpolateGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_rasterize.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_rasterize.cu new file mode 100644 index 0000000000000000000000000000000000000000..de7e1e540e1154705418a348c722698923f4fca4 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_rasterize.cu @@ -0,0 +1,242 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +struct RasterizeFwdOp : public OpKernel +{ + RasterizeGLState m_glState; // OpenGL-related persistent state. + int m_tri_const; // 1 if triangle array is known to be constant. + + RasterizeFwdOp(OpKernelConstruction* ctx): + OpKernel(ctx) + { + memset(&m_glState, 0, sizeof(RasterizeGLState)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_db", &m_glState.enableDB)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tri_const", &m_tri_const)); + } + + void Compute(OpKernelContext* ctx) + { + cudaStream_t stream = ctx->eigen_device().stream(); + + // Check that input shapes are correct. + const Tensor& pos = ctx->input(0); + const Tensor& tri = ctx->input(1); + const Tensor& resolution = ctx->input(2); + const Tensor& ranges = ctx->input(3); + + // Determine number of outputs + int num_outputs = m_glState.enableDB ? 2 : 1; + + // Determine instance mode and check input dimensions. + bool instance_mode = pos.dims() > 2; + if (instance_mode) + { + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) > 0 && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("instance mode - pos must have shape [>0, >0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, resolution.dims() == 1 && resolution.dim_size(0) == 2, errors::InvalidArgument("resolution must have shape [2]")); + } + else + { + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("range mode - pos must have shape [>0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, resolution.dims() == 1 && resolution.dim_size(0) == 2, errors::InvalidArgument("resolution must have shape [2]")); + OP_REQUIRES(ctx, ranges.dims() == 2 && ranges.dim_size(0) > 0 && ranges.dim_size(1) == 2, errors::InvalidArgument("range mode - ranges must have shape [>0, 2]")); + } + + // Get output shape. + const int32_t* res_in = resolution.flat().data(); // This is in CPU memory. + int height = res_in[0]; + int width = res_in[1]; + int depth = instance_mode ? pos.dim_size(0) : ranges.dim_size(0); + OP_REQUIRES(ctx, height > 0 && width > 0, errors::InvalidArgument("resolution must be [>0, >0]")); + + // Get position and triangle buffer sizes in int32/float32. + int posCount = 4 * pos.dim_size(0) * (instance_mode ? pos.dim_size(1) : 1); + int triCount = 3 * tri.dim_size(0); + + // Init context and GL? + bool initCtx = !m_glState.glFBO; + if (initCtx) + { + const DeviceBase::GpuDeviceInfo* g = ctx->device()->tensorflow_gpu_device_info(); + int cudaDeviceIdx = g ? g->gpu_id : -1; + rasterizeInitGLContext(ctx, m_glState, cudaDeviceIdx); // In common/rasterize.cpp + } + else + setGLContext(m_glState.glctx); // (Re-)Activate GL context. + + // Resize all buffers. + bool changes = false; + rasterizeResizeBuffers(ctx, m_glState, changes, posCount, triCount, width, height, depth); // In common/rasterize_gl.cpp + if (changes) + { +#ifdef _WIN32 + // Workaround for occasional blank first frame on Windows. + releaseGLContext(); + setGLContext(m_glState.glctx); +#endif + } + + // Copy input data to GL and render. + const float* posPtr = pos.flat().data(); + const int32_t* rangesPtr = instance_mode ? 0 : ranges.flat().data(); // This is in CPU memory. + const int32_t* triPtr = (initCtx || !m_tri_const) ? tri.flat().data() : NULL; // Copy triangles only if needed. + int vtxPerInstance = instance_mode ? pos.dim_size(1) : 0; + rasterizeRender(ctx, m_glState, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, rangesPtr, width, height, depth, -1); + + // Allocate output tensors. + TensorShape output_shape; + output_shape.AddDim(depth); + output_shape.AddDim(height); + output_shape.AddDim(width); + output_shape.AddDim(4); + float* outputPtr[2]; + for (int i=0; i < 2; i++) + { + if (i >= num_outputs) + output_shape.set_dim(3, 0); // Zero channels for unwanted out_db tensor. + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(ctx, ctx->allocate_output(i, output_shape, &output_tensor)); + if (i < num_outputs) + outputPtr[i] = output_tensor->flat().data(); + } + + // Copy rasterized results into CUDA buffers. + rasterizeCopyResults(ctx, m_glState, stream, outputPtr, width, height, depth); + + // Done. Release GL context. + releaseGLContext(); + } +}; + +REGISTER_OP("RasterizeFwd") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("resolution: int32") + .Input ("ranges: int32") + .Output ("out: float") + .Output ("out_db: float") + .Attr ("enable_db: int") + .Attr ("tri_const: int"); + +REGISTER_KERNEL_BUILDER(Name("RasterizeFwd").Device(DEVICE_GPU).HostMemory("resolution").HostMemory("ranges"), RasterizeFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +template +struct RasterizeGradOp : public OpKernel +{ + RasterizeGradParams m_attribs; + + RasterizeGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + } + + void Compute(OpKernelContext* ctx) + { + RasterizeGradParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + // Input tensors. + const Tensor& pos = ctx->input(0); + const Tensor& tri = ctx->input(1); + const Tensor& out = ctx->input(2); + const Tensor& dy = ctx->input(3); + const Tensor& ddb = ctx->input(ENABLE_DB ? 4 : 3); + + // Determine instance mode. + p.instance_mode = (pos.dims() > 2) ? 1 : 0; + + // Shape is taken from the rasterizer output tensor. + OP_REQUIRES(ctx, out.dims() == 4, errors::InvalidArgument("out must be rank-4")); + p.depth = out.dim_size(0); + p.height = out.dim_size(1); + p.width = out.dim_size(2); + OP_REQUIRES(ctx, p.depth > 0 && p.height > 0 && p.width > 0, errors::InvalidArgument("resolution must be [>0, >0, >0]")); + + // Check other shapes. + if (p.instance_mode) + OP_REQUIRES(ctx, pos.dims() == 3 && pos.dim_size(0) == p.depth && pos.dim_size(1) > 0 && pos.dim_size(2) == 4, errors::InvalidArgument("pos must have shape [depth, >0, 4]")); + else + OP_REQUIRES(ctx, pos.dims() == 2 && pos.dim_size(0) > 0 && pos.dim_size(1) == 4, errors::InvalidArgument("pos must have shape [>0, 4]")); + OP_REQUIRES(ctx, tri.dims() == 2 && tri.dim_size(0) > 0 && tri.dim_size(1) == 3, errors::InvalidArgument("tri must have shape [>0, 3]")); + OP_REQUIRES(ctx, out.dims() == 4 && out.dim_size(0) == p.depth && out.dim_size(1) == p.height && out.dim_size(2) == p.width && out.dim_size(3) == 4, errors::InvalidArgument("out must have shape [depth, height, width, 4]")); + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) == p.depth && dy.dim_size(1) == p.height && dy.dim_size(2) == p.width && dy.dim_size(3) == 4, errors::InvalidArgument("dy must have shape [depth, height, width, 4]")); + if (ENABLE_DB) + OP_REQUIRES(ctx, ddb.dims() == 4 && ddb.dim_size(0) == p.depth && ddb.dim_size(1) == p.height && ddb.dim_size(2) == p.width && ddb.dim_size(3) == 4, errors::InvalidArgument("ddb must have shape [depth, height, width, 4]")); + + // Populate parameters. + p.numTriangles = tri.dim_size(0); + p.numVertices = p.instance_mode ? pos.dim_size(1) : pos.dim_size(0); + p.pos = pos.flat().data(); + p.tri = tri.flat().data(); + p.out = out.flat().data(); + p.dy = dy.flat().data(); + p.ddb = ENABLE_DB ? ddb.flat().data() : 0; + + // Set up pixel position to clip space x, y transform. + p.xs = 2.f / (float)p.width; + p.xo = 1.f / (float)p.width - 1.f; + p.ys = 2.f / (float)p.height; + p.yo = 1.f / (float)p.height - 1.f; + + // Allocate output tensor for position gradients. + Tensor* grad_tensor = NULL; + TensorShape grad_shape; + if (p.instance_mode) + grad_shape.AddDim(p.depth); + grad_shape.AddDim(p.numVertices); + grad_shape.AddDim(4); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_shape, &grad_tensor)); + p.grad = grad_tensor->flat().data(); + + // Clear the output buffers. + size_t gradBytes = (p.instance_mode ? p.depth : 1) * p.numVertices * 4 * sizeof(float); + cudaMemsetAsync(p.grad, 0, gradBytes, stream); + + // Verify that buffers are aligned to allow float2/float4 operations. + OP_REQUIRES(ctx, !((uintptr_t)p.pos & 15), errors::Internal("pos input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.dy & 7), errors::Internal("dy input tensor not aligned to float2")); + if (ENABLE_DB) + OP_REQUIRES(ctx, !((uintptr_t)p.ddb & 15), errors::Internal("ddb input tensor not aligned to float4")); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH, RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = ENABLE_DB ? (void*)RasterizeGradKernelDb : (void*)RasterizeGradKernel; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("RasterizeGrad") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("out: float") + .Input ("dy: float") + .Output ("grad: float"); + +REGISTER_OP("RasterizeGradDb") + .Input ("pos: float") + .Input ("tri: int32") + .Input ("out: float") + .Input ("dy: float") + .Input ("ddb: float") + .Output ("grad: float"); + +REGISTER_KERNEL_BUILDER(Name("RasterizeGrad") .Device(DEVICE_GPU), RasterizeGradOp); +REGISTER_KERNEL_BUILDER(Name("RasterizeGradDb").Device(DEVICE_GPU), RasterizeGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu new file mode 100644 index 0000000000000000000000000000000000000000..a43db1e20e2edc34f535afa548b166d823fe087f --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/tensorflow/tf_texture.cu @@ -0,0 +1,525 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Common op attribute parser. + +static __host__ void parseOpAttributes(OpKernelConstruction* ctx, TextureKernelParams& p) +{ + // Mip and filter modes. + OP_REQUIRES_OK(ctx, ctx->GetAttr("filter_mode", &p.filterMode)); + OP_REQUIRES(ctx, p.filterMode >= 0 && p.filterMode < TEX_MODE_COUNT, errors::InvalidArgument("filter_mode unsupported")); + p.enableMip = (p.filterMode == TEX_MODE_LINEAR_MIPMAP_NEAREST || p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR); + + // Mip level clamp. + if (p.enableMip) + { + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_mip_level", &p.mipLevelLimit)); + OP_REQUIRES(ctx, p.mipLevelLimit >= -1, errors::InvalidArgument("invalid max_mip_level")); + ctx->GetAttr("tex_const", &p.texConst); // Only available in forward op. + } + + // Boundary mode. + OP_REQUIRES_OK(ctx, ctx->GetAttr("boundary_mode", &p.boundaryMode)); + OP_REQUIRES(ctx, p.boundaryMode >= 0 && p.boundaryMode < TEX_BOUNDARY_MODE_COUNT, errors::InvalidArgument("boundary_mode unsupported")); +} + +//------------------------------------------------------------------------ +// Forward TensorFlow op. + +struct TextureFwdOp : public OpKernel +{ + TextureKernelParams m_attribs; + PersistentTensor m_persistentMipTensor; // Used if texture is constant and mips are enabled. + bool m_persistentMipTensorInitialized; + + TextureFwdOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + m_persistentMipTensorInitialized = false; + parseOpAttributes(ctx, m_attribs); + } + + void Compute(OpKernelContext* ctx) + { + TextureKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + bool cube_mode = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE); + + // Get input. + const Tensor& tex = ctx->input(0); + const Tensor& uv = ctx->input(1); + const Tensor& uv_da = ctx->input(p.enableMip ? 2 : 1); + + // Extract input dimensions. + p.n = (uv.dims() > 0) ? uv.dim_size(0) : 0; + p.imgHeight = (uv.dims() > 1) ? uv.dim_size(1) : 0; + p.imgWidth = (uv.dims() > 2) ? uv.dim_size(2) : 0; + p.texDepth = (tex.dims() > 0) ? tex.dim_size(0) : 0; + if (!cube_mode) + { + p.texHeight = (tex.dims() > 1) ? tex.dim_size(1) : 0; + p.texWidth = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.channels = (tex.dims() > 3) ? tex.dim_size(3) : 0; + } + else + { + p.texHeight = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.texWidth = (tex.dims() > 3) ? tex.dim_size(3) : 0; + p.channels = (tex.dims() > 4) ? tex.dim_size(4) : 0; + } + + // Sanity checks. + if (!cube_mode) + { + OP_REQUIRES(ctx, tex.dims() == 4 && tex.dim_size(0) > 0 && tex.dim_size(1) > 0 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0, errors::InvalidArgument("tex must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 2, errors::InvalidArgument("uv must have shape [>0, >0, >0, 2]")); + } + else + { + OP_REQUIRES(ctx, tex.dims() == 5 && tex.dim_size(0) > 0 && tex.dim_size(1) == 6 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0 && tex.dim_size(4) > 0, errors::InvalidArgument("tex must have shape[>0, 6, >0, >0, >0] in cube map mode")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 3, errors::InvalidArgument("uv must have shape [>0, >0, >0, 3] in cube map mode")); + OP_REQUIRES(ctx, tex.dim_size(2) == tex.dim_size(3), errors::InvalidArgument("texture shape must be square in cube map mode")); + } + OP_REQUIRES(ctx, tex.dim_size(0) == 1 || tex.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs tex, uv")); + OP_REQUIRES(ctx, p.texWidth <= (1 << TEX_MAX_MIP_LEVEL) && p.texHeight <= (1 << TEX_MAX_MIP_LEVEL), errors::InvalidArgument("texture size too large")); + if (p.enableMip) + { + if (!cube_mode) + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 4, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 4]")); + else + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 6, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 6] in cube map mode")); + } + + // Get input pointers. + p.tex[0] = tex.flat().data(); + p.uv = uv.flat().data(); + p.uvDA = p.enableMip ? uv_da.flat().data() : 0; + + // Allocate output tensor. + Tensor* out_tensor = NULL; + TensorShape out_shape; + out_shape.AddDim(p.n); + out_shape.AddDim(p.imgHeight); + out_shape.AddDim(p.imgWidth); + out_shape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out_tensor)); + p.out = out_tensor->flat().data(); + + // Choose kernel variants based on channel count. + void* args[] = {&p}; + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + float* pmip = 0; + if (p.enableMip) + { + // Generate mip offsets. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(ctx, p, mipOffsets); + + // Mip output tensor. + Tensor* mip_tensor = NULL; + TensorShape mip_shape; + mip_shape.AddDim(mipTotal); + + // If texture is constant, calculate mip stack only once. + bool computeMip = true; + if (p.texConst) + { + // First execution? + if (!m_persistentMipTensorInitialized) + { + // Allocate a persistent mip tensor. + OP_REQUIRES_OK(ctx, ctx->allocate_persistent(DT_FLOAT, mip_shape, &m_persistentMipTensor, &mip_tensor)); + m_persistentMipTensorInitialized = true; + } + else + { + // Reuse the persistent tensor, do not recompute mip levels. + mip_tensor = m_persistentMipTensor.AccessTensor(ctx); + computeMip = false; + } + + // Set as output tensor as well. + ctx->set_output(1, *mip_tensor); + } + else + { + // Allocate an output tensor as usual. + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, mip_shape, &mip_tensor)); + } + + pmip = mip_tensor->flat().data(); // Pointer to data. + for (int i=1; i <= p.mipLevelMax; i++) + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + + // Build mip levels if needed. + if (computeMip) + { + for (int i=1; i <= p.mipLevelMax; i++) + { + int2 ms = mipLevelSize(p, i); + int3 sz = make_int3(ms.x, ms.y, p.texDepth); + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_MIP_KERNEL_BLOCK_HEIGHT, sz.x, sz.y); + dim3 gridSize = getLaunchGridSize(blockSize, sz.x, sz.y, sz.z * (cube_mode ? 6 : 1)); + p.mipLevelOut = i; + + void* build_func_tbl[3] = { (void*)MipBuildKernel1, (void*)MipBuildKernel2, (void*)MipBuildKernel4 }; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(build_func_tbl[channel_div_idx], gridSize, blockSize, args, 0, stream)); + } + } + } + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. + if (!cube_mode) + OP_REQUIRES(ctx, !((uintptr_t)p.uv & 7), errors::Internal("uv input tensor not aligned to float2")); + if ((p.channels & 3) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 15), errors::Internal("tex input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.out & 15), errors::Internal("out output tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 15), errors::Internal("mip output tensor not aligned to float4")); + } + if ((p.channels & 1) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 7), errors::Internal("tex input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.out & 7), errors::Internal("out output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 7), errors::Internal("mip output tensor not aligned to float2")); + } + if (!cube_mode) + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 15), errors::Internal("uv_da input tensor not aligned to float4")); + else + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 7), errors::Internal("uv_da input tensor not aligned to float2")); + + // Choose launch parameters for texture lookup kernel. + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + // Choose kernel based on filter mode, cube mode, and datatype. + void* func_tbl[TEX_MODE_COUNT * 3 * 2] = { + (void*)TextureFwdKernelNearest1, + (void*)TextureFwdKernelNearest2, + (void*)TextureFwdKernelNearest4, + (void*)TextureFwdKernelLinear1, + (void*)TextureFwdKernelLinear2, + (void*)TextureFwdKernelLinear4, + (void*)TextureFwdKernelLinearMipmapNearest1, + (void*)TextureFwdKernelLinearMipmapNearest2, + (void*)TextureFwdKernelLinearMipmapNearest4, + (void*)TextureFwdKernelLinearMipmapLinear1, + (void*)TextureFwdKernelLinearMipmapLinear2, + (void*)TextureFwdKernelLinearMipmapLinear4, + (void*)TextureFwdKernelCubeNearest1, + (void*)TextureFwdKernelCubeNearest2, + (void*)TextureFwdKernelCubeNearest4, + (void*)TextureFwdKernelCubeLinear1, + (void*)TextureFwdKernelCubeLinear2, + (void*)TextureFwdKernelCubeLinear4, + (void*)TextureFwdKernelCubeLinearMipmapNearest1, + (void*)TextureFwdKernelCubeLinearMipmapNearest2, + (void*)TextureFwdKernelCubeLinearMipmapNearest4, + (void*)TextureFwdKernelCubeLinearMipmapLinear1, + (void*)TextureFwdKernelCubeLinearMipmapLinear2, + (void*)TextureFwdKernelCubeLinearMipmapLinear4, + }; + + // Function index. + int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; + func_idx = func_idx * 3 + channel_div_idx; + + // Launch kernel. + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("TextureFwd") + .Input ("tex: float") + .Input ("uv: float") + .Output ("out: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int"); + +REGISTER_OP("TextureFwdMip") + .Input ("tex: float") + .Input ("uv: float") + .Input ("uv_da: float") + .Output ("out: float") + .Output ("mip: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int") + .Attr ("tex_const: int") + .Attr ("max_mip_level: int"); + +REGISTER_KERNEL_BUILDER(Name("TextureFwd") .Device(DEVICE_GPU), TextureFwdOp); +REGISTER_KERNEL_BUILDER(Name("TextureFwdMip").Device(DEVICE_GPU), TextureFwdOp); + +//------------------------------------------------------------------------ +// Gradient TensorFlow op. + +struct TextureGradOp : public OpKernel +{ + TextureKernelParams m_attribs; + + TextureGradOp(OpKernelConstruction* ctx): OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + parseOpAttributes(ctx, m_attribs); + } + + void Compute(OpKernelContext* ctx) + { + TextureKernelParams& p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + bool cube_mode = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE); + + // Get input. + const Tensor& tex = ctx->input(0); + const Tensor& uv = ctx->input(1); + const Tensor& dy = ctx->input(2); + const Tensor& uv_da = ctx->input(p.enableMip ? 3 : 2); + const Tensor& mip = ctx->input(p.enableMip ? 4 : 2); + + // Extract input dimensions. + p.n = (uv.dims() > 0) ? uv.dim_size(0) : 0; + p.imgHeight = (uv.dims() > 1) ? uv.dim_size(1) : 0; + p.imgWidth = (uv.dims() > 2) ? uv.dim_size(2) : 0; + p.texDepth = (tex.dims() > 0) ? tex.dim_size(0) : 0; + if (!cube_mode) + { + p.texHeight = (tex.dims() > 1) ? tex.dim_size(1) : 0; + p.texWidth = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.channels = (tex.dims() > 3) ? tex.dim_size(3) : 0; + } + else + { + p.texHeight = (tex.dims() > 2) ? tex.dim_size(2) : 0; + p.texWidth = (tex.dims() > 3) ? tex.dim_size(3) : 0; + p.channels = (tex.dims() > 4) ? tex.dim_size(4) : 0; + } + + // Sanity checks. + if (!cube_mode) + { + OP_REQUIRES(ctx, tex.dims() == 4 && tex.dim_size(0) > 0 && tex.dim_size(1) > 0 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0, errors::InvalidArgument("tex must have shape[>0, >0, >0, >0]")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 2, errors::InvalidArgument("uv must have shape [>0, >0, >0, 2]")); + } + else + { + OP_REQUIRES(ctx, tex.dims() == 5 && tex.dim_size(0) > 0 && tex.dim_size(1) == 6 && tex.dim_size(2) > 0 && tex.dim_size(3) > 0 && tex.dim_size(4) > 0, errors::InvalidArgument("tex must have shape[>0, 6, >0, >0, >0] in cube map mode")); + OP_REQUIRES(ctx, uv.dims() == 4 && uv.dim_size(0) > 0 && uv.dim_size(1) > 0 && uv.dim_size(2) > 0 && uv.dim_size(3) == 3, errors::InvalidArgument("uv must have shape [>0, >0, >0, 3] in cube map mode")); + OP_REQUIRES(ctx, tex.dim_size(2) == tex.dim_size(3), errors::InvalidArgument("texture shape must be square in cube map mode")); + } + OP_REQUIRES(ctx, tex.dim_size(0) == 1 || tex.dim_size(0) == p.n, errors::InvalidArgument("minibatch size mismatch between inputs tex, uv")); + OP_REQUIRES(ctx, dy.dims() == 4 && dy.dim_size(0) == p.n && dy.dim_size(1) == p.imgHeight && dy.dim_size(2) == p.imgWidth && dy.dim_size(3) == p.channels, errors::InvalidArgument("dy must have shape [minibatch_size, height, width, channels]")); + if (p.enableMip) + { + if (!cube_mode) + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 4, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 4]")); + else + OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 6, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 6] in cube map mode")); + } + + // Get input pointers. + p.tex[0] = tex.flat().data(); + p.uv = uv.flat().data(); + p.dy = dy.flat().data(); + p.uvDA = p.enableMip ? uv_da.flat().data() : 0; + float* pmip = p.enableMip ? (float*)mip.flat().data() : 0; + + // Allocate output tensor for tex gradient. + Tensor* grad_tex_tensor = NULL; + TensorShape grad_tex_shape; + grad_tex_shape.AddDim(p.texDepth); + if (cube_mode) + grad_tex_shape.AddDim(6); + grad_tex_shape.AddDim(p.texHeight); + grad_tex_shape.AddDim(p.texWidth); + grad_tex_shape.AddDim(p.channels); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_tex_shape, &grad_tex_tensor)); + p.gradTex[0] = grad_tex_tensor->flat().data(); + + // Allocate output tensor for uv gradient. + if (p.filterMode != TEX_MODE_NEAREST) + { + TensorShape grad_uv_shape; + Tensor* grad_uv_tensor = NULL; + grad_uv_shape.AddDim(p.n); + grad_uv_shape.AddDim(p.imgHeight); + grad_uv_shape.AddDim(p.imgWidth); + grad_uv_shape.AddDim(uv.dim_size(3)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, grad_uv_shape, &grad_uv_tensor)); + p.gradUV = grad_uv_tensor->flat().data(); + + // Allocate output tensor for uv_da gradient. + if (p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + Tensor* grad_uv_da_tensor = NULL; + grad_uv_shape.set_dim(3, uv_da.dim_size(3)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(2, grad_uv_shape, &grad_uv_da_tensor)); + p.gradUVDA = grad_uv_da_tensor->flat().data(); + } + } + + // Choose kernel variants based on channel count. + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + Tensor grad_mip_tensor; + float* pgradMip = 0; + if (p.enableMip) + { + // Generate mip offsets. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(ctx, p, mipOffsets); + + // Get space for temporary mip gradients. + TensorShape grad_mip_shape; + grad_mip_shape.AddDim(mipTotal); + ctx->allocate_temp(DT_FLOAT, grad_mip_shape, &grad_mip_tensor); + pgradMip = grad_mip_tensor.flat().data(); + for (int i=1; i <= p.mipLevelMax; i++) + { + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + p.gradTex[i] = pgradMip + mipOffsets[i]; // Pointers to mip gradients. + } + + // Clear mip gradients. + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(pgradMip, 0, mipTotal * sizeof(float), stream)); + } + + // Initialize texture gradients to zero. + int texBytes = p.texHeight * p.texWidth * p.texDepth * p.channels * sizeof(float); + if (cube_mode) + texBytes *= 6; + OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.gradTex[0], 0, texBytes, stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. + if (!cube_mode) + { + OP_REQUIRES(ctx, !((uintptr_t)p.uv & 7), errors::Internal("uv input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradUV & 7), errors::Internal("grad_uv output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 15), errors::Internal("uv_da input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradUVDA & 15), errors::Internal("grad_uv_da output tensor not aligned to float4")); + } + else + { + OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 7), errors::Internal("uv_da input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradUVDA & 7), errors::Internal("grad_uv_da output tensor not aligned to float2")); + } + if ((p.channels & 3) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 15), errors::Internal("tex input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradTex[0] & 15), errors::Internal("grad_tex output tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)p.dy & 15), errors::Internal("dy input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 15), errors::Internal("mip input tensor not aligned to float4")); + OP_REQUIRES(ctx, !((uintptr_t)pgradMip & 15), errors::Internal("internal mip gradient tensor not aligned to float4")); + } + if ((p.channels & 1) == 0) + { + OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 7), errors::Internal("tex input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.gradTex[0] & 7), errors::Internal("grad_tex output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)p.dy & 7), errors::Internal("dy output tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)pmip & 7), errors::Internal("mip input tensor not aligned to float2")); + OP_REQUIRES(ctx, !((uintptr_t)pgradMip & 7), errors::Internal("internal mip gradient tensor not aligned to float2")); + } + + // Choose launch parameters for main gradient kernel. + void* args[] = {&p}; + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + void* func_tbl[TEX_MODE_COUNT * 2] = { + (void*)TextureGradKernelNearest, + (void*)TextureGradKernelLinear, + (void*)TextureGradKernelLinearMipmapNearest, + (void*)TextureGradKernelLinearMipmapLinear, + (void*)TextureGradKernelCubeNearest, + (void*)TextureGradKernelCubeLinear, + (void*)TextureGradKernelCubeLinearMipmapNearest, + (void*)TextureGradKernelCubeLinearMipmapLinear, + }; + + // Function index. + int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; + + // Launch main gradient kernel. + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + + // Launch kernel to pull gradients from mip levels. + if (p.enableMip) + { + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT, p.texWidth, p.texHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.texWidth, p.texHeight, p.texDepth * (cube_mode ? 6 : 1)); + int sharedBytes = blockSize.x * blockSize.y * p.channels * sizeof(float); + + void* mip_grad_func_tbl[3] = { (void*)MipGradKernel1, (void*)MipGradKernel2, (void*)MipGradKernel4 }; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(mip_grad_func_tbl[channel_div_idx], gridSize, blockSize, args, sharedBytes, stream)); + } + } +}; + +REGISTER_OP("TextureGradNearest") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Output ("grad_tex: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int"); + +REGISTER_OP("TextureGradLinear") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Output ("grad_tex: float") + .Output ("grad_uv: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int"); + +REGISTER_OP("TextureGradLinearMipmapNearest") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Input ("uv_da: float") + .Input ("mip: float") + .Output ("grad_tex: float") + .Output ("grad_uv: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int") + .Attr ("max_mip_level: int"); + +REGISTER_OP("TextureGradLinearMipmapLinear") + .Input ("tex: float") + .Input ("uv: float") + .Input ("dy: float") + .Input ("uv_da: float") + .Input ("mip: float") + .Output ("grad_tex: float") + .Output ("grad_uv: float") + .Output ("grad_uv_da: float") + .Attr ("filter_mode: int") + .Attr ("boundary_mode: int") + .Attr ("max_mip_level: int"); + +REGISTER_KERNEL_BUILDER(Name("TextureGradNearest") .Device(DEVICE_GPU), TextureGradOp); +REGISTER_KERNEL_BUILDER(Name("TextureGradLinear") .Device(DEVICE_GPU), TextureGradOp); +REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapNearest").Device(DEVICE_GPU), TextureGradOp); +REGISTER_KERNEL_BUILDER(Name("TextureGradLinearMipmapLinear") .Device(DEVICE_GPU), TextureGradOp); + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/__init__.py b/extensions/nvdiffrast/nvdiffrast/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c12ddfffee1da1f71a5cb34bbb79455f180263c5 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .ops import RasterizeCudaContext, RasterizeGLContext, get_log_level, set_log_level, rasterize, DepthPeeler, interpolate, texture, texture_construct_mip, antialias, antialias_construct_topology_hash +__all__ = ["RasterizeCudaContext", "RasterizeGLContext", "get_log_level", "set_log_level", "rasterize", "DepthPeeler", "interpolate", "texture", "texture_construct_mip", "antialias", "antialias_construct_topology_hash"] diff --git a/extensions/nvdiffrast/nvdiffrast/torch/ops.py b/extensions/nvdiffrast/nvdiffrast/torch/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5514bbb5254e5cd99f7ae3870069c7e4e81ec2dd --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/ops.py @@ -0,0 +1,734 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import importlib +import logging +import numpy as np +import os +import torch +import torch.utils.cpp_extension +from . import _C + +#---------------------------------------------------------------------------- +# C++/Cuda plugin compiler/loader. + +_cached_plugin = {} +def _get_plugin(gl=False): + assert isinstance(gl, bool) + + # Modified with precompiled torch CUDA extension + if not gl: + return _C + + # Return cached plugin if already loaded. + if _cached_plugin.get(gl, None) is not None: + return _cached_plugin[gl] + + # Make sure we can find the necessary compiler and libary binaries. + if os.name == 'nt': + lib_dir = os.path.dirname(__file__) + r"\..\lib" + def find_cl_path(): + import glob + def get_sort_key(x): + # Primary criterion is VS version, secondary is edition, third is internal MSVC version. + x = x.split('\\')[3:] + x[1] = {'BuildTools': '~0', 'Community': '~1', 'Pro': '~2', 'Professional': '~3', 'Enterprise': '~4'}.get(x[1], x[1]) + return x + vs_relative_path = r"\Microsoft Visual Studio\*\*\VC\Tools\MSVC\*\bin\Hostx64\x64" + paths = glob.glob(r"C:\Program Files" + vs_relative_path) + paths += glob.glob(r"C:\Program Files (x86)" + vs_relative_path) + if paths: + return sorted(paths, key=get_sort_key)[-1] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ['PATH'] += ';' + cl_path + + # Compiler options. + common_opts = ['-DNVDR_TORCH'] + cc_opts = [] + if os.name == 'nt': + cc_opts += ['/wd4067', '/wd4624'] # Disable warnings in torch headers. + + # Linker options for the GL-interfacing plugin. + ldflags = [] + if gl: + if os.name == 'posix': + ldflags = ['-lGL', '-lEGL'] + elif os.name == 'nt': + libs = ['gdi32', 'opengl32', 'user32', 'setgpu'] + ldflags = ['/LIBPATH:' + lib_dir] + ['/DEFAULTLIB:' + x for x in libs] + + # List of source files. + if gl: + source_files = [ + '../common/common.cpp', + '../common/glutil.cpp', + '../common/rasterize_gl.cpp', + 'torch_bindings_gl.cpp', + 'torch_rasterize_gl.cpp', + ] + else: + source_files = [ + '../common/cudaraster/impl/Buffer.cpp', + '../common/cudaraster/impl/CudaRaster.cpp', + '../common/cudaraster/impl/RasterImpl.cu', + '../common/cudaraster/impl/RasterImpl.cpp', + '../common/common.cpp', + '../common/rasterize.cu', + '../common/interpolate.cu', + '../common/texture.cu', + '../common/texture.cpp', + '../common/antialias.cu', + 'torch_bindings.cpp', + 'torch_rasterize.cpp', + 'torch_interpolate.cpp', + 'torch_texture.cpp', + 'torch_antialias.cpp', + ] + + # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin. + if gl and (os.name == 'posix') and ('libGLEW' in os.environ.get('LD_PRELOAD', '')): + logging.getLogger('nvdiffrast').warning("Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin") + + # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. + plugin_name = 'nvdiffrast_plugin' + ('_gl' if gl else '') + try: + lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory(plugin_name, False), 'lock') + if os.path.exists(lock_fn): + logging.getLogger('nvdiffrast').warning("Lock file exists in build directory: '%s'" % lock_fn) + except: + pass + + # Speed up compilation on Windows. + if os.name == 'nt': + # Skip telemetry sending step in vcvarsall.bat + os.environ['VSCMD_SKIP_SENDTELEMETRY'] = '1' + + # Opportunistically patch distutils to cache MSVC environments. + try: + import distutils._msvccompiler + import functools + if not hasattr(distutils._msvccompiler._get_vc_env, '__wrapped__'): + distutils._msvccompiler._get_vc_env = functools.lru_cache()(distutils._msvccompiler._get_vc_env) + except: + pass + + # Compile and load. + source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] + torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=common_opts+cc_opts, extra_cuda_cflags=common_opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False) + + # Import, cache, and return the compiled module. + _cached_plugin[gl] = importlib.import_module(plugin_name) + return _cached_plugin[gl] + +#---------------------------------------------------------------------------- +# Log level. +#---------------------------------------------------------------------------- + +def get_log_level(): + '''Get current log level. + + Returns: + Current log level in nvdiffrast. See `set_log_level()` for possible values. + ''' + return _get_plugin().get_log_level() + +def set_log_level(level): + '''Set log level. + + Log levels follow the convention on the C++ side of Torch: + 0 = Info, + 1 = Warning, + 2 = Error, + 3 = Fatal. + The default log level is 1. + + Args: + level: New log level as integer. Internal nvdiffrast messages of this + severity or higher will be printed, while messages of lower + severity will be silent. + ''' + _get_plugin().set_log_level(level) + +#---------------------------------------------------------------------------- +# CudaRaster state wrapper. +#---------------------------------------------------------------------------- + +class RasterizeCudaContext: + def __init__(self, device=None): + '''Create a new Cuda rasterizer context. + + The context is deleted and internal storage is released when the object is + destroyed. + + Args: + device (Optional): Cuda device on which the context is created. Type can be + `torch.device`, string (e.g., `'cuda:1'`), or int. If not + specified, context will be created on currently active Cuda + device. + Returns: + The newly created Cuda rasterizer context. + ''' + if device is None: + cuda_device_idx = torch.cuda.current_device() + else: + with torch.cuda.device(device): + cuda_device_idx = torch.cuda.current_device() + self.cpp_wrapper = _get_plugin().RasterizeCRStateWrapper(cuda_device_idx) + self.output_db = True + self.active_depth_peeler = None + +#---------------------------------------------------------------------------- +# GL state wrapper. +#---------------------------------------------------------------------------- + +class RasterizeGLContext: + def __init__(self, output_db=True, mode='automatic', device=None): + '''Create a new OpenGL rasterizer context. + + Creating an OpenGL context is a slow operation so you should usually reuse the same + context in all calls to `rasterize()` on the same CPU thread. The OpenGL context + is deleted when the object is destroyed. + + Side note: When using the OpenGL context in a rasterization operation, the + context's internal framebuffer object is automatically enlarged to accommodate the + rasterization operation's output shape, but it is never shrunk in size until the + context is destroyed. Thus, if you need to rasterize, say, deep low-resolution + tensors and also shallow high-resolution tensors, you can conserve GPU memory by + creating two separate OpenGL contexts for these tasks. In this scenario, using the + same OpenGL context for both tasks would end up reserving GPU memory for a deep, + high-resolution output tensor. + + Args: + output_db (bool): Compute and output image-space derivates of barycentrics. + mode: OpenGL context handling mode. Valid values are 'manual' and 'automatic'. + device (Optional): Cuda device on which the context is created. Type can be + `torch.device`, string (e.g., `'cuda:1'`), or int. If not + specified, context will be created on currently active Cuda + device. + Returns: + The newly created OpenGL rasterizer context. + ''' + assert output_db is True or output_db is False + assert mode in ['automatic', 'manual'] + self.output_db = output_db + self.mode = mode + if device is None: + cuda_device_idx = torch.cuda.current_device() + else: + with torch.cuda.device(device): + cuda_device_idx = torch.cuda.current_device() + self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(output_db, mode == 'automatic', cuda_device_idx) + self.active_depth_peeler = None # For error checking only. + + def set_context(self): + '''Set (activate) OpenGL context in the current CPU thread. + Only available if context was created in manual mode. + ''' + assert self.mode == 'manual' + self.cpp_wrapper.set_context() + + def release_context(self): + '''Release (deactivate) currently active OpenGL context. + Only available if context was created in manual mode. + ''' + assert self.mode == 'manual' + self.cpp_wrapper.release_context() + +#---------------------------------------------------------------------------- +# Rasterize. +#---------------------------------------------------------------------------- + +class _rasterize_func(torch.autograd.Function): + @staticmethod + def forward(ctx, raster_ctx, pos, tri, resolution, ranges, grad_db, peeling_idx): + if isinstance(raster_ctx, RasterizeGLContext): + out, out_db = _get_plugin(gl=True).rasterize_fwd_gl(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx) + else: + out, out_db = _get_plugin().rasterize_fwd_cuda(raster_ctx.cpp_wrapper, pos, tri, resolution, ranges, peeling_idx) + ctx.save_for_backward(pos, tri, out) + ctx.saved_grad_db = grad_db + return out, out_db + + @staticmethod + def backward(ctx, dy, ddb): + pos, tri, out = ctx.saved_tensors + if ctx.saved_grad_db: + g_pos = _get_plugin().rasterize_grad_db(pos, tri, out, dy, ddb) + else: + g_pos = _get_plugin().rasterize_grad(pos, tri, out, dy) + return None, g_pos, None, None, None, None, None + +# Op wrapper. +def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True): + '''Rasterize triangles. + + All input tensors must be contiguous and reside in GPU memory except for + the `ranges` tensor that, if specified, has to reside in CPU memory. The + output tensors will be contiguous and reside in GPU memory. + + Args: + glctx: Rasterizer context of type `RasterizeGLContext` or `RasterizeCudaContext`. + pos: Vertex position tensor with dtype `torch.float32`. To enable range + mode, this tensor should have a 2D shape [num_vertices, 4]. To enable + instanced mode, use a 3D shape [minibatch_size, num_vertices, 4]. + tri: Triangle tensor with shape [num_triangles, 3] and dtype `torch.int32`. + resolution: Output resolution as integer tuple (height, width). + ranges: In range mode, tensor with shape [minibatch_size, 2] and dtype + `torch.int32`, specifying start indices and counts into `tri`. + Ignored in instanced mode. + grad_db: Propagate gradients of image-space derivatives of barycentrics + into `pos` in backward pass. Ignored if using an OpenGL context that + was not configured to output image-space derivatives. + + Returns: + A tuple of two tensors. The first output tensor has shape [minibatch_size, + height, width, 4] and contains the main rasterizer output in order (u, v, z/w, + triangle_id). If the OpenGL context was configured to output image-space + derivatives of barycentrics, the second output tensor will also have shape + [minibatch_size, height, width, 4] and contain said derivatives in order + (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape + [minibatch_size, height, width, 0]. + ''' + assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext)) + assert grad_db is True or grad_db is False + grad_db = grad_db and glctx.output_db + + # Sanitize inputs. + assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor) + resolution = tuple(resolution) + if ranges is None: + ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu') + else: + assert isinstance(ranges, torch.Tensor) + + # Check that context is not currently reserved for depth peeling. + if glctx.active_depth_peeler is not None: + return RuntimeError("Cannot call rasterize() during depth peeling operation, use rasterize_next_layer() instead") + + # Instantiate the function. + return _rasterize_func.apply(glctx, pos, tri, resolution, ranges, grad_db, -1) + +#---------------------------------------------------------------------------- +# Depth peeler context manager for rasterizing multiple depth layers. +#---------------------------------------------------------------------------- + +class DepthPeeler: + def __init__(self, glctx, pos, tri, resolution, ranges=None, grad_db=True): + '''Create a depth peeler object for rasterizing multiple depth layers. + + Arguments are the same as in `rasterize()`. + + Returns: + The newly created depth peeler. + ''' + assert isinstance(glctx, (RasterizeGLContext, RasterizeCudaContext)) + assert grad_db is True or grad_db is False + grad_db = grad_db and glctx.output_db + + # Sanitize inputs as usual. + assert isinstance(pos, torch.Tensor) and isinstance(tri, torch.Tensor) + resolution = tuple(resolution) + if ranges is None: + ranges = torch.empty(size=(0, 2), dtype=torch.int32, device='cpu') + else: + assert isinstance(ranges, torch.Tensor) + + # Store all the parameters. + self.raster_ctx = glctx + self.pos = pos + self.tri = tri + self.resolution = resolution + self.ranges = ranges + self.grad_db = grad_db + self.peeling_idx = None + + def __enter__(self): + if self.raster_ctx is None: + raise RuntimeError("Cannot re-enter a terminated depth peeling operation") + if self.raster_ctx.active_depth_peeler is not None: + raise RuntimeError("Cannot have multiple depth peelers active simultaneously in a rasterization context") + self.raster_ctx.active_depth_peeler = self + self.peeling_idx = 0 + return self + + def __exit__(self, *args): + assert self.raster_ctx.active_depth_peeler is self + self.raster_ctx.active_depth_peeler = None + self.raster_ctx = None # Remove all references to input tensor so they're not left dangling. + self.pos = None + self.tri = None + self.resolution = None + self.ranges = None + self.grad_db = None + self.peeling_idx = None + return None + + def rasterize_next_layer(self): + '''Rasterize next depth layer. + + Operation is equivalent to `rasterize()` except that previously reported + surface points are culled away. + + Returns: + A tuple of two tensors as in `rasterize()`. + ''' + assert self.raster_ctx.active_depth_peeler is self + assert self.peeling_idx >= 0 + result = _rasterize_func.apply(self.raster_ctx, self.pos, self.tri, self.resolution, self.ranges, self.grad_db, self.peeling_idx) + self.peeling_idx += 1 + return result + +#---------------------------------------------------------------------------- +# Interpolate. +#---------------------------------------------------------------------------- + +# Output pixel differentials for at least some attributes. +class _interpolate_func_da(torch.autograd.Function): + @staticmethod + def forward(ctx, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list): + out, out_da = _get_plugin().interpolate_fwd_da(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + ctx.save_for_backward(attr, rast, tri, rast_db) + ctx.saved_misc = diff_attrs_all, diff_attrs_list + return out, out_da + + @staticmethod + def backward(ctx, dy, dda): + attr, rast, tri, rast_db = ctx.saved_tensors + diff_attrs_all, diff_attrs_list = ctx.saved_misc + g_attr, g_rast, g_rast_db = _get_plugin().interpolate_grad_da(attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) + return g_attr, g_rast, None, g_rast_db, None, None + +# No pixel differential for any attribute. +class _interpolate_func(torch.autograd.Function): + @staticmethod + def forward(ctx, attr, rast, tri): + out, out_da = _get_plugin().interpolate_fwd(attr, rast, tri) + ctx.save_for_backward(attr, rast, tri) + return out, out_da + + @staticmethod + def backward(ctx, dy, _): + attr, rast, tri = ctx.saved_tensors + g_attr, g_rast = _get_plugin().interpolate_grad(attr, rast, tri, dy) + return g_attr, g_rast, None + +# Op wrapper. +def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None): + """Interpolate vertex attributes. + + All input tensors must be contiguous and reside in GPU memory. The output tensors + will be contiguous and reside in GPU memory. + + Args: + attr: Attribute tensor with dtype `torch.float32`. + Shape is [num_vertices, num_attributes] in range mode, or + [minibatch_size, num_vertices, num_attributes] in instanced mode. + Broadcasting is supported along the minibatch axis. + rast: Main output tensor from `rasterize()`. + tri: Triangle tensor with shape [num_triangles, 3] and dtype `torch.int32`. + rast_db: (Optional) Tensor containing image-space derivatives of barycentrics, + i.e., the second output tensor from `rasterize()`. Enables computing + image-space derivatives of attributes. + diff_attrs: (Optional) List of attribute indices for which image-space + derivatives are to be computed. Special value 'all' is equivalent + to list [0, 1, ..., num_attributes - 1]. + + Returns: + A tuple of two tensors. The first output tensor contains interpolated + attributes and has shape [minibatch_size, height, width, num_attributes]. + If `rast_db` and `diff_attrs` were specified, the second output tensor contains + the image-space derivatives of the selected attributes and has shape + [minibatch_size, height, width, 2 * len(diff_attrs)]. The derivatives of the + first selected attribute A will be on channels 0 and 1 as (dA/dX, dA/dY), etc. + Otherwise, the second output tensor will be an empty tensor with shape + [minibatch_size, height, width, 0]. + """ + # Sanitize the list of pixel differential attributes. + if diff_attrs is None: + diff_attrs = [] + elif diff_attrs != 'all': + diff_attrs = np.asarray(diff_attrs, np.int32) + assert len(diff_attrs.shape) == 1 + diff_attrs = diff_attrs.tolist() + + diff_attrs_all = int(diff_attrs == 'all') + diff_attrs_list = [] if diff_attrs_all else diff_attrs + + # Check inputs. + assert all(isinstance(x, torch.Tensor) for x in (attr, rast, tri)) + if diff_attrs: + assert isinstance(rast_db, torch.Tensor) + + # Choose stub. + if diff_attrs: + return _interpolate_func_da.apply(attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list) + else: + return _interpolate_func.apply(attr, rast, tri) + +#---------------------------------------------------------------------------- +# Texture +#---------------------------------------------------------------------------- + +# Linear-mipmap-linear and linear-mipmap-nearest: Mipmaps enabled. +class _texture_func_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack): + empty = torch.tensor([]) + if uv_da is None: + uv_da = empty + if mip_level_bias is None: + mip_level_bias = empty + if mip_wrapper is None: + mip_wrapper = _get_plugin().TextureMipWrapper() + out = _get_plugin().texture_fwd_mip(tex, uv, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) + ctx.save_for_backward(tex, uv, uv_da, mip_level_bias, *mip_stack) + ctx.saved_misc = filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum + return out + + @staticmethod + def backward(ctx, dy): + tex, uv, uv_da, mip_level_bias, *mip_stack = ctx.saved_tensors + filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum = ctx.saved_misc + if filter_mode == 'linear-mipmap-linear': + g_tex, g_uv, g_uv_da, g_mip_level_bias, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) + return (None, g_tex, g_uv, g_uv_da, g_mip_level_bias, None, None, None) + tuple(g_mip_stack) + else: # linear-mipmap-nearest + g_tex, g_uv, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum) + return (None, g_tex, g_uv, None, None, None, None, None) + tuple(g_mip_stack) + +# Linear and nearest: Mipmaps disabled. +class _texture_func(torch.autograd.Function): + @staticmethod + def forward(ctx, filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum): + out = _get_plugin().texture_fwd(tex, uv, filter_mode_enum, boundary_mode_enum) + ctx.save_for_backward(tex, uv) + ctx.saved_misc = filter_mode, filter_mode_enum, boundary_mode_enum + return out + + @staticmethod + def backward(ctx, dy): + tex, uv = ctx.saved_tensors + filter_mode, filter_mode_enum, boundary_mode_enum = ctx.saved_misc + if filter_mode == 'linear': + g_tex, g_uv = _get_plugin().texture_grad_linear(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return None, g_tex, g_uv, None, None + else: # nearest + g_tex = _get_plugin().texture_grad_nearest(tex, uv, dy, filter_mode_enum, boundary_mode_enum) + return None, g_tex, None, None, None + +# Op wrapper. +def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='auto', boundary_mode='wrap', max_mip_level=None): + """Perform texture sampling. + + All input tensors must be contiguous and reside in GPU memory. The output tensor + will be contiguous and reside in GPU memory. + + Args: + tex: Texture tensor with dtype `torch.float32`. For 2D textures, must have shape + [minibatch_size, tex_height, tex_width, tex_channels]. For cube map textures, + must have shape [minibatch_size, 6, tex_height, tex_width, tex_channels] where + tex_width and tex_height are equal. Note that `boundary_mode` must also be set + to 'cube' to enable cube map mode. Broadcasting is supported along the minibatch axis. + uv: Tensor containing per-pixel texture coordinates. When sampling a 2D texture, + must have shape [minibatch_size, height, width, 2]. When sampling a cube map + texture, must have shape [minibatch_size, height, width, 3]. + uv_da: (Optional) Tensor containing image-space derivatives of texture coordinates. + Must have same shape as `uv` except for the last dimension that is to be twice + as long. + mip_level_bias: (Optional) Per-pixel bias for mip level selection. If `uv_da` is omitted, + determines mip level directly. Must have shape [minibatch_size, height, width]. + mip: (Optional) Preconstructed mipmap stack from a `texture_construct_mip()` call, or a list + of tensors specifying a custom mipmap stack. When specifying a custom mipmap stack, + the tensors in the list must follow the same format as `tex` except for width and + height that must follow the usual rules for mipmap sizes. The base level texture + is still supplied in `tex` and must not be included in the list. Gradients of a + custom mipmap stack are not automatically propagated to base texture but the mipmap + tensors will receive gradients of their own. If a mipmap stack is not specified + but the chosen filter mode requires it, the mipmap stack is constructed internally + and discarded afterwards. + filter_mode: Texture filtering mode to be used. Valid values are 'auto', 'nearest', + 'linear', 'linear-mipmap-nearest', and 'linear-mipmap-linear'. Mode 'auto' + selects 'linear' if neither `uv_da` or `mip_level_bias` is specified, and + 'linear-mipmap-linear' when at least one of them is specified, these being + the highest-quality modes possible depending on the availability of the + image-space derivatives of the texture coordinates or direct mip level information. + boundary_mode: Valid values are 'wrap', 'clamp', 'zero', and 'cube'. If `tex` defines a + cube map, this must be set to 'cube'. The default mode 'wrap' takes fractional + part of texture coordinates. Mode 'clamp' clamps texture coordinates to the + centers of the boundary texels. Mode 'zero' virtually extends the texture with + all-zero values in all directions. + max_mip_level: If specified, limits the number of mipmaps constructed and used in mipmap-based + filter modes. + + Returns: + A tensor containing the results of the texture sampling with shape + [minibatch_size, height, width, tex_channels]. Cube map fetches with invalid uv coordinates + (e.g., zero vectors) output all zeros and do not propagate gradients. + """ + + # Default filter mode. + if filter_mode == 'auto': + filter_mode = 'linear-mipmap-linear' if (uv_da is not None or mip_level_bias is not None) else 'linear' + + # Sanitize inputs. + if max_mip_level is None: + max_mip_level = -1 + else: + max_mip_level = int(max_mip_level) + assert max_mip_level >= 0 + + # Check inputs. + assert isinstance(tex, torch.Tensor) and isinstance(uv, torch.Tensor) + if 'mipmap' in filter_mode: + assert isinstance(uv_da, torch.Tensor) or isinstance(mip_level_bias, torch.Tensor) + + # If mipping disabled via max level=0, we may as well use simpler filtering internally. + if max_mip_level == 0 and filter_mode in ['linear-mipmap-nearest', 'linear-mipmap-linear']: + filter_mode = 'linear' + + # Convert filter mode to internal enumeration. + filter_mode_dict = {'nearest': 0, 'linear': 1, 'linear-mipmap-nearest': 2, 'linear-mipmap-linear': 3} + filter_mode_enum = filter_mode_dict[filter_mode] + + # Convert boundary mode to internal enumeration. + boundary_mode_dict = {'cube': 0, 'wrap': 1, 'clamp': 2, 'zero': 3} + boundary_mode_enum = boundary_mode_dict[boundary_mode] + + # Construct a mipmap if necessary. + if 'mipmap' in filter_mode: + mip_wrapper, mip_stack = None, [] + if mip is not None: + assert isinstance(mip, (_get_plugin().TextureMipWrapper, list)) + if isinstance(mip, list): + assert all(isinstance(x, torch.Tensor) for x in mip) + mip_stack = mip + else: + mip_wrapper = mip + else: + mip_wrapper = _get_plugin().texture_construct_mip(tex, max_mip_level, boundary_mode == 'cube') + + # Choose stub. + if filter_mode == 'linear-mipmap-linear' or filter_mode == 'linear-mipmap-nearest': + return _texture_func_mip.apply(filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack) + else: + return _texture_func.apply(filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum) + +# Mipmap precalculation for cases where the texture stays constant. +def texture_construct_mip(tex, max_mip_level=None, cube_mode=False): + """Construct a mipmap stack for a texture. + + This function can be used for constructing a mipmap stack for a texture that is known to remain + constant. This avoids reconstructing it every time `texture()` is called. + + Args: + tex: Texture tensor with the same constraints as in `texture()`. + max_mip_level: If specified, limits the number of mipmaps constructed. + cube_mode: Must be set to True if `tex` specifies a cube map texture. + + Returns: + An opaque object containing the mipmap stack. This can be supplied in a call to `texture()` + in the `mip` argument. + """ + + assert isinstance(tex, torch.Tensor) + assert cube_mode is True or cube_mode is False + if max_mip_level is None: + max_mip_level = -1 + else: + max_mip_level = int(max_mip_level) + assert max_mip_level >= 0 + return _get_plugin().texture_construct_mip(tex, max_mip_level, cube_mode) + +#---------------------------------------------------------------------------- +# Antialias. +#---------------------------------------------------------------------------- + +class _antialias_func(torch.autograd.Function): + @staticmethod + def forward(ctx, color, rast, pos, tri, topology_hash, pos_gradient_boost): + out, work_buffer = _get_plugin().antialias_fwd(color, rast, pos, tri, topology_hash) + ctx.save_for_backward(color, rast, pos, tri) + ctx.saved_misc = pos_gradient_boost, work_buffer + return out + + @staticmethod + def backward(ctx, dy): + color, rast, pos, tri = ctx.saved_tensors + pos_gradient_boost, work_buffer = ctx.saved_misc + g_color, g_pos = _get_plugin().antialias_grad(color, rast, pos, tri, dy, work_buffer) + if pos_gradient_boost != 1.0: + g_pos = g_pos * pos_gradient_boost + return g_color, None, g_pos, None, None, None + +# Op wrapper. +def antialias(color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0): + """Perform antialiasing. + + All input tensors must be contiguous and reside in GPU memory. The output tensor + will be contiguous and reside in GPU memory. + + Note that silhouette edge determination is based on vertex indices in the triangle + tensor. For it to work properly, a vertex belonging to multiple triangles must be + referred to using the same vertex index in each triangle. Otherwise, nvdiffrast will always + classify the adjacent edges as silhouette edges, which leads to bad performance and + potentially incorrect gradients. If you are unsure whether your data is good, check + which pixels are modified by the antialias operation and compare to the example in the + documentation. + + Args: + color: Input image to antialias with shape [minibatch_size, height, width, num_channels]. + rast: Main output tensor from `rasterize()`. + pos: Vertex position tensor used in the rasterization operation. + tri: Triangle tensor used in the rasterization operation. + topology_hash: (Optional) Preconstructed topology hash for the triangle tensor. If not + specified, the topology hash is constructed internally and discarded afterwards. + pos_gradient_boost: (Optional) Multiplier for gradients propagated to `pos`. + + Returns: + A tensor containing the antialiased image with the same shape as `color` input tensor. + """ + + # Check inputs. + assert all(isinstance(x, torch.Tensor) for x in (color, rast, pos, tri)) + + # Construct topology hash unless provided by user. + if topology_hash is not None: + assert isinstance(topology_hash, _get_plugin().TopologyHashWrapper) + else: + topology_hash = _get_plugin().antialias_construct_topology_hash(tri) + + # Instantiate the function. + return _antialias_func.apply(color, rast, pos, tri, topology_hash, pos_gradient_boost) + +# Topology hash precalculation for cases where the triangle array stays constant. +def antialias_construct_topology_hash(tri): + """Construct a topology hash for a triangle tensor. + + This function can be used for constructing a topology hash for a triangle tensor that is + known to remain constant. This avoids reconstructing it every time `antialias()` is called. + + Args: + tri: Triangle tensor with shape [num_triangles, 3]. Must be contiguous and reside in + GPU memory. + + Returns: + An opaque object containing the topology hash. This can be supplied in a call to + `antialias()` in the `topology_hash` argument. + """ + assert isinstance(tri, torch.Tensor) + return _get_plugin().antialias_construct_topology_hash(tri) + +#---------------------------------------------------------------------------- diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_antialias.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_antialias.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9e52f957daa2d7bffe6e6ec52e1435bbf1ea2f55 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_antialias.cpp @@ -0,0 +1,243 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/antialias.h" + +//------------------------------------------------------------------------ +// Kernel prototypes. + +void AntialiasFwdMeshKernel (const AntialiasKernelParams p); +void AntialiasFwdDiscontinuityKernel(const AntialiasKernelParams p); +void AntialiasFwdAnalysisKernel (const AntialiasKernelParams p); +void AntialiasGradKernel (const AntialiasKernelParams p); + +//------------------------------------------------------------------------ +// Topology hash construction. + +TopologyHashWrapper antialias_construct_topology_hash(torch::Tensor tri) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tri)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AntialiasKernelParams p = {}; // Initialize all fields to zero. + + // Check inputs. + NVDR_CHECK_DEVICE(tri); + NVDR_CHECK_CONTIGUOUS(tri); + NVDR_CHECK_I32(tri); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + + // Fill in kernel parameters. + p.numTriangles = tri.size(0); + p.numVertices = 0x7fffffff; // Let's not require vertex positions just to enable an error check. + p.tri = tri.data_ptr(); + + // Kernel parameters. + p.allocTriangles = 64; + while (p.allocTriangles < p.numTriangles) + p.allocTriangles <<= 1; // Must be power of two. + + // Construct the hash tensor and get pointer. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + torch::Tensor ev_hash = torch::zeros({(uint64_t)p.allocTriangles * AA_HASH_ELEMENTS_PER_TRIANGLE(p.allocTriangles) * 4}, opts); + p.evHash = (uint4*)(ev_hash.data_ptr()); + + // Check alignment. + NVDR_CHECK(!((uintptr_t)p.evHash & 15), "ev_hash internal tensor not aligned to int4"); + + // Populate the hash. + void* args[] = {&p}; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasFwdMeshKernel, (p.numTriangles - 1) / AA_MESH_KERNEL_THREADS_PER_BLOCK + 1, AA_MESH_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + + // Return. + TopologyHashWrapper hash_wrap; + hash_wrap.ev_hash = ev_hash; + return hash_wrap; +} + +//------------------------------------------------------------------------ +// Forward op. + +std::tuple antialias_fwd(torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, TopologyHashWrapper topology_hash_wrap) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(color)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AntialiasKernelParams p = {}; // Initialize all fields to zero. + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + torch::Tensor& topology_hash = topology_hash_wrap.ev_hash; // Unwrap. + + // Check inputs. + NVDR_CHECK_DEVICE(color, rast, pos, tri, topology_hash); + NVDR_CHECK_CONTIGUOUS(color, rast, pos, tri, topology_hash); + NVDR_CHECK_F32(color, rast, pos); + NVDR_CHECK_I32(tri, topology_hash); + + // Sanity checks. + NVDR_CHECK(color.sizes().size() == 4 && color.size(0) > 0 && color.size(1) > 0 && color.size(2) > 0 && color.size(3) > 0, "color must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "rast must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK(color.size(1) == rast.size(1) && color.size(2) == rast.size(2), "color and rast inputs must have same spatial dimensions"); + if (p.instance_mode) + { + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0) && pos.size(0) == color.size(0), "minibatch size mismatch between inputs color, rast, pos"); + } + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0), "minibatch size mismatch between inputs color, rast"); + } + + // Extract input dimensions. + p.numVertices = pos.size(p.instance_mode ? 1 : 0); + p.numTriangles = tri.size(0); + p.n = color.size(0); + p.height = color.size(1); + p.width = color.size(2); + p.channels = color.size(3); + + // Get input pointers. + p.color = color.data_ptr(); + p.rasterOut = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.pos = pos.data_ptr(); + p.evHash = (uint4*)(topology_hash.data_ptr()); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Determine hash allocation size. + p.allocTriangles = 64; + while (p.allocTriangles < p.numTriangles) + p.allocTriangles <<= 1; // Must be power of two. + + // Allocate output tensors. + torch::Tensor out = color.detach().clone(); // Use color as base. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor work_buffer = torch::empty({p.n * p.width * p.height * 8 + 4}, opts); // 8 int for a maximum of two work items per pixel. + p.output = out.data_ptr(); + p.workBuffer = (int4*)(work_buffer.data_ptr()); + + // Clear the work counters. + NVDR_CHECK_CUDA_ERROR(cudaMemsetAsync(p.workBuffer, 0, sizeof(int4), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.rasterOut & 7), "raster_out input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.workBuffer & 15), "work_buffer internal tensor not aligned to int4"); + NVDR_CHECK(!((uintptr_t)p.evHash & 15), "topology_hash internal tensor not aligned to int4"); + + // Choose launch parameters for the discontinuity finder kernel and launch. + void* args[] = {&p}; + dim3 blockSize(AA_DISCONTINUITY_KERNEL_BLOCK_WIDTH, AA_DISCONTINUITY_KERNEL_BLOCK_HEIGHT, 1); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.n); + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasFwdDiscontinuityKernel, gridSize, blockSize, args, 0, stream)); + + // Determine optimum block size for the persistent analysis kernel and launch. + int device = 0; + int numCTA = 0; + int numSM = 0; + NVDR_CHECK_CUDA_ERROR(cudaGetDevice(&device)); + NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasFwdAnalysisKernel, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, 0)); + NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device)); + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasFwdAnalysisKernel, numCTA * numSM, AA_ANALYSIS_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + + // Return results. + return std::tuple(out, work_buffer); +} + +//------------------------------------------------------------------------ +// Gradient op. + +std::tuple antialias_grad(torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, torch::Tensor dy, torch::Tensor work_buffer) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(color)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AntialiasKernelParams p = {}; // Initialize all fields to zero. + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + + // Check inputs. + NVDR_CHECK_DEVICE(color, rast, pos, tri, dy, work_buffer); + NVDR_CHECK_CONTIGUOUS(color, rast, pos, tri, work_buffer); + NVDR_CHECK_F32(color, rast, pos, dy, work_buffer); + NVDR_CHECK_I32(tri); + + // Sanity checks. + NVDR_CHECK(dy.sizes().size() == 4 && dy.size(0) > 0 && dy.size(1) > 0 && dy.size(2) > 0 && dy.size(3) > 0, "dy must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(color.sizes().size() == 4 && color.size(0) > 0 && color.size(1) > 0 && color.size(2) > 0 && color.size(3) > 0, "color must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "raster_out must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK(color.size(1) == rast.size(1) && color.size(2) == rast.size(2), "color and raster_out inputs must have same spatial dimensions"); + NVDR_CHECK(color.size(1) == dy.size(1) && color.size(2) == dy.size(2) && color.size(3) == dy.size(3), "color and dy inputs must have same dimensions"); + if (p.instance_mode) + { + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0) && pos.size(0) == color.size(0), "minibatch size mismatch between inputs color, raster_out, pos"); + NVDR_CHECK(dy.size(0) == color.size(0) && rast.size(0) == color.size(0) && pos.size(0) ==color.size(0), "minibatch size mismatch between inputs dy, color, raster_out, pos"); + } + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "pos must have shape [>0, >0, 4] or [>0, 4]"); + NVDR_CHECK(rast.size(0) == color.size(0), "minibatch size mismatch between inputs color, raster_out"); + NVDR_CHECK(dy.size(0) == color.size(0) && rast.size(0) == color.size(0), "minibatch size mismatch between inputs dy, color, raster_out"); + } + + // Extract input dimensions. + p.numVertices = pos.size(p.instance_mode ? 1 : 0); + p.numTriangles = tri.size(0); + p.n = color.size(0); + p.height = color.size(1); + p.width = color.size(2); + p.channels = color.size(3); + + // Ensure dy is contiguous. + torch::Tensor dy_ = dy.contiguous(); + + // Get input pointers. + p.color = color.data_ptr(); + p.rasterOut = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.pos = pos.data_ptr(); + p.dy = dy_.data_ptr(); + p.workBuffer = (int4*)(work_buffer.data_ptr()); + + // Misc parameters. + p.xh = .5f * (float)p.width; + p.yh = .5f * (float)p.height; + + // Allocate output tensors. + torch::Tensor grad_color = dy_.detach().clone(); // Use dy as base. + torch::Tensor grad_pos = torch::zeros_like(pos); + p.gradColor = grad_color.data_ptr(); + p.gradPos = grad_pos.data_ptr(); + + // Clear gradient kernel work counter. + NVDR_CHECK_CUDA_ERROR(cudaMemsetAsync(&p.workBuffer[0].y, 0, sizeof(int), stream)); + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.workBuffer & 15), "work_buffer internal tensor not aligned to int4"); + + // Determine optimum block size for the gradient kernel and launch. + void* args[] = {&p}; + int device = 0; + int numCTA = 0; + int numSM = 0; + NVDR_CHECK_CUDA_ERROR(cudaGetDevice(&device)); + NVDR_CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numCTA, (void*)AntialiasGradKernel, AA_GRAD_KERNEL_THREADS_PER_BLOCK, 0)); + NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, device)); + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)AntialiasGradKernel, numCTA * numSM, AA_GRAD_KERNEL_THREADS_PER_BLOCK, args, 0, stream)); + + // Return results. + return std::tuple(grad_color, grad_pos); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..431b68c629504510e474d92909395c4d3d06c34c --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include + +//------------------------------------------------------------------------ +// Op prototypes. Return type macros for readability. + +#define OP_RETURN_T torch::Tensor +#define OP_RETURN_TT std::tuple +#define OP_RETURN_TTT std::tuple +#define OP_RETURN_TTTT std::tuple +#define OP_RETURN_TTV std::tuple > +#define OP_RETURN_TTTTV std::tuple > + +OP_RETURN_TT rasterize_fwd_cuda (RasterizeCRStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple resolution, torch::Tensor ranges, int peeling_idx); +OP_RETURN_T rasterize_grad (torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy); +OP_RETURN_T rasterize_grad_db (torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy, torch::Tensor ddb); +OP_RETURN_TT interpolate_fwd (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri); +OP_RETURN_TT interpolate_fwd_da (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor rast_db, bool diff_attrs_all, std::vector& diff_attrs_vec); +OP_RETURN_TT interpolate_grad (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy); +OP_RETURN_TTT interpolate_grad_da (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy, torch::Tensor rast_db, torch::Tensor dda, bool diff_attrs_all, std::vector& diff_attrs_vec); +TextureMipWrapper texture_construct_mip (torch::Tensor tex, int max_mip_level, bool cube_mode); +OP_RETURN_T texture_fwd (torch::Tensor tex, torch::Tensor uv, int filter_mode, int boundary_mode); +OP_RETURN_T texture_fwd_mip (torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode); +OP_RETURN_T texture_grad_nearest (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode); +OP_RETURN_TT texture_grad_linear (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode); +OP_RETURN_TTV texture_grad_linear_mipmap_nearest (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode); +OP_RETURN_TTTTV texture_grad_linear_mipmap_linear (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode); +TopologyHashWrapper antialias_construct_topology_hash (torch::Tensor tri); +OP_RETURN_TT antialias_fwd (torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, TopologyHashWrapper topology_hash); +OP_RETURN_TT antialias_grad (torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, torch::Tensor dy, torch::Tensor work_buffer); + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // State classes. + pybind11::class_(m, "RasterizeCRStateWrapper").def(pybind11::init()); + pybind11::class_(m, "TextureMipWrapper").def(pybind11::init<>()); + pybind11::class_(m, "TopologyHashWrapper"); + + // Plumbing to torch/c10 logging system. + m.def("get_log_level", [](void) { return FLAGS_caffe2_log_level; }, "get log level"); + m.def("set_log_level", [](int level){ FLAGS_caffe2_log_level = level; }, "set log level"); + + // Ops. + m.def("rasterize_fwd_cuda", &rasterize_fwd_cuda, "rasterize forward op (cuda)"); + m.def("rasterize_grad", &rasterize_grad, "rasterize gradient op ignoring db gradients"); + m.def("rasterize_grad_db", &rasterize_grad_db, "rasterize gradient op with db gradients"); + m.def("interpolate_fwd", &interpolate_fwd, "interpolate forward op with attribute derivatives"); + m.def("interpolate_fwd_da", &interpolate_fwd_da, "interpolate forward op without attribute derivatives"); + m.def("interpolate_grad", &interpolate_grad, "interpolate gradient op with attribute derivatives"); + m.def("interpolate_grad_da", &interpolate_grad_da, "interpolate gradient op without attribute derivatives"); + m.def("texture_construct_mip", &texture_construct_mip, "texture mipmap construction"); + m.def("texture_fwd", &texture_fwd, "texture forward op without mipmapping"); + m.def("texture_fwd_mip", &texture_fwd_mip, "texture forward op with mipmapping"); + m.def("texture_grad_nearest", &texture_grad_nearest, "texture gradient op in nearest mode"); + m.def("texture_grad_linear", &texture_grad_linear, "texture gradient op in linear mode"); + m.def("texture_grad_linear_mipmap_nearest", &texture_grad_linear_mipmap_nearest, "texture gradient op in linear-mipmap-nearest mode"); + m.def("texture_grad_linear_mipmap_linear", &texture_grad_linear_mipmap_linear, "texture gradient op in linear-mipmap-linear mode"); + m.def("antialias_construct_topology_hash", &antialias_construct_topology_hash, "antialias topology hash construction"); + m.def("antialias_fwd", &antialias_fwd, "antialias forward op"); + m.def("antialias_grad", &antialias_grad, "antialias gradient op"); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings_gl.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings_gl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f403d73cf3149531ccedd277e6b8b26afa5cdda4 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_bindings_gl.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include + +//------------------------------------------------------------------------ +// Op prototypes. + +std::tuple rasterize_fwd_gl(RasterizeGLStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple resolution, torch::Tensor ranges, int peeling_idx); + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // State classes. + pybind11::class_(m, "RasterizeGLStateWrapper").def(pybind11::init()) + .def("set_context", &RasterizeGLStateWrapper::setContext) + .def("release_context", &RasterizeGLStateWrapper::releaseContext); + + // Ops. + m.def("rasterize_fwd_gl", &rasterize_fwd_gl, "rasterize forward op (opengl)"); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_common.inl b/extensions/nvdiffrast/nvdiffrast/torch/torch_common.inl new file mode 100644 index 0000000000000000000000000000000000000000..3809ecef8b091fa911d1fdf89c5b84dc1038e9a1 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_common.inl @@ -0,0 +1,29 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#pragma once +#include "../common/framework.h" + +//------------------------------------------------------------------------ +// Input check helpers. +//------------------------------------------------------------------------ + +#ifdef _MSC_VER +#define __func__ __FUNCTION__ +#endif + +#define NVDR_CHECK_DEVICE(...) do { TORCH_CHECK(at::cuda::check_device({__VA_ARGS__}), __func__, "(): Inputs " #__VA_ARGS__ " must reside on the same GPU device") } while(0) +#define NVDR_CHECK_CPU(...) do { nvdr_check_cpu({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must reside on CPU"); } while(0) +#define NVDR_CHECK_CONTIGUOUS(...) do { nvdr_check_contiguous({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must be contiguous tensors"); } while(0) +#define NVDR_CHECK_F32(...) do { nvdr_check_f32({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must be float32 tensors"); } while(0) +#define NVDR_CHECK_I32(...) do { nvdr_check_i32({__VA_ARGS__}, __func__, "(): Inputs " #__VA_ARGS__ " must be int32 tensors"); } while(0) +inline void nvdr_check_cpu(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.device().type() == c10::DeviceType::CPU, func, err_msg); } +inline void nvdr_check_contiguous(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.is_contiguous(), func, err_msg); } +inline void nvdr_check_f32(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.dtype() == torch::kFloat32, func, err_msg); } +inline void nvdr_check_i32(at::ArrayRef ts, const char* func, const char* err_msg) { for (const at::Tensor& t : ts) TORCH_CHECK(t.dtype() == torch::kInt32, func, err_msg); } +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_interpolate.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_interpolate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3f303a2ec9ae02a1db81d326966415efca6568e --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_interpolate.cpp @@ -0,0 +1,250 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "../common/common.h" +#include "../common/interpolate.h" + +//------------------------------------------------------------------------ +// Kernel prototypes. + +void InterpolateFwdKernel (const InterpolateKernelParams p); +void InterpolateFwdKernelDa (const InterpolateKernelParams p); +void InterpolateGradKernel (const InterpolateKernelParams p); +void InterpolateGradKernelDa(const InterpolateKernelParams p); + +//------------------------------------------------------------------------ +// Helper + +static void set_diff_attrs(InterpolateKernelParams& p, bool diff_attrs_all, std::vector& diff_attrs_vec) +{ + if (diff_attrs_all) + { + p.numDiffAttr = p.numAttr; + p.diff_attrs_all = 1; + } + else + { + NVDR_CHECK(diff_attrs_vec.size() <= IP_MAX_DIFF_ATTRS, "too many entries in diff_attrs list (increase IP_MAX_DIFF_ATTRS)"); + p.numDiffAttr = diff_attrs_vec.size(); + memcpy(p.diffAttrs, &diff_attrs_vec[0], diff_attrs_vec.size()*sizeof(int)); + } +} + +//------------------------------------------------------------------------ +// Forward op. + +std::tuple interpolate_fwd_da(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor rast_db, bool diff_attrs_all, std::vector& diff_attrs_vec) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(attr)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + InterpolateKernelParams p = {}; // Initialize all fields to zero. + bool enable_da = (rast_db.defined()) && (diff_attrs_all || !diff_attrs_vec.empty()); + p.instance_mode = (attr.sizes().size() > 2) ? 1 : 0; + + // Check inputs. + if (enable_da) + { + NVDR_CHECK_DEVICE(attr, rast, tri, rast_db); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri, rast_db); + NVDR_CHECK_F32(attr, rast, rast_db); + NVDR_CHECK_I32(tri); + } + else + { + NVDR_CHECK_DEVICE(attr, rast, tri); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri); + NVDR_CHECK_F32(attr, rast); + NVDR_CHECK_I32(tri); + } + + // Sanity checks. + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "rast must have shape[>0, >0, >0, 4]"); + NVDR_CHECK( tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK((attr.sizes().size() == 2 || attr.sizes().size() == 3) && attr.size(0) > 0 && attr.size(1) > 0 && (attr.sizes().size() == 2 || attr.size(2) > 0), "attr must have shape [>0, >0, >0] or [>0, >0]"); + if (p.instance_mode) + NVDR_CHECK(attr.size(0) == rast.size(0) || attr.size(0) == 1, "minibatch size mismatch between inputs rast, attr"); + if (enable_da) + { + NVDR_CHECK(rast_db.sizes().size() == 4 && rast_db.size(0) > 0 && rast_db.size(1) > 0 && rast_db.size(2) > 0 && rast_db.size(3) == 4, "rast_db must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(rast_db.size(1) == rast.size(1) && rast_db.size(2) == rast.size(2), "spatial size mismatch between inputs rast and rast_db"); + NVDR_CHECK(rast_db.size(0) == rast.size(0), "minibatch size mismatch between inputs rast, rast_db"); + } + + // Extract input dimensions. + p.numVertices = attr.size(p.instance_mode ? 1 : 0); + p.numAttr = attr.size(p.instance_mode ? 2 : 1); + p.numTriangles = tri.size(0); + p.height = rast.size(1); + p.width = rast.size(2); + p.depth = rast.size(0); + + // Set attribute pixel differential info if enabled, otherwise leave as zero. + if (enable_da) + set_diff_attrs(p, diff_attrs_all, diff_attrs_vec); + else + p.numDiffAttr = 0; + + // Get input pointers. + p.attr = attr.data_ptr(); + p.rast = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.rastDB = enable_da ? rast_db.data_ptr() : NULL; + p.attrBC = (p.instance_mode && attr.size(0) == 1) ? 1 : 0; + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({p.depth, p.height, p.width, p.numAttr}, opts); + torch::Tensor out_da = torch::empty({p.depth, p.height, p.width, p.numDiffAttr * 2}, opts); + + p.out = out.data_ptr(); + p.outDA = enable_da ? out_da.data_ptr() : NULL; + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.rast & 15), "rast input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.rastDB & 15), "rast_db input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.outDA & 7), "out_da output tensor not aligned to float2"); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_FWD_MAX_KERNEL_BLOCK_WIDTH, IP_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = enable_da ? (void*)InterpolateFwdKernelDa : (void*)InterpolateFwdKernel; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + + // Return results. + return std::tuple(out, out_da); +} + +// Version without derivatives. +std::tuple interpolate_fwd(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri) +{ + std::vector empty_vec; + torch::Tensor empty_tensor; + return interpolate_fwd_da(attr, rast, tri, empty_tensor, false, empty_vec); +} + +//------------------------------------------------------------------------ +// Gradient op. + +std::tuple interpolate_grad_da(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy, torch::Tensor rast_db, torch::Tensor dda, bool diff_attrs_all, std::vector& diff_attrs_vec) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(attr)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + InterpolateKernelParams p = {}; // Initialize all fields to zero. + bool enable_da = (rast_db.defined()) && (diff_attrs_all || !diff_attrs_vec.empty()); + p.instance_mode = (attr.sizes().size() > 2) ? 1 : 0; + + // Check inputs. + if (enable_da) + { + NVDR_CHECK_DEVICE(attr, rast, tri, dy, rast_db, dda); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri, rast_db); + NVDR_CHECK_F32(attr, rast, dy, rast_db, dda); + NVDR_CHECK_I32(tri); + } + else + { + NVDR_CHECK_DEVICE(attr, rast, tri, dy); + NVDR_CHECK_CONTIGUOUS(attr, rast, tri); + NVDR_CHECK_F32(attr, rast, dy); + NVDR_CHECK_I32(tri); + } + + // Depth of attributes. + int attr_depth = p.instance_mode ? (attr.sizes().size() > 1 ? attr.size(0) : 0) : 1; + + // Sanity checks. + NVDR_CHECK(rast.sizes().size() == 4 && rast.size(0) > 0 && rast.size(1) > 0 && rast.size(2) > 0 && rast.size(3) == 4, "rast must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK((attr.sizes().size() == 2 || attr.sizes().size() == 3) && attr.size(0) > 0 && attr.size(1) > 0 && (attr.sizes().size() == 2 || attr.size(2) > 0), "attr must have shape [>0, >0, >0] or [>0, >0]"); + NVDR_CHECK(dy.sizes().size() == 4 && dy.size(0) > 0 && dy.size(1) == rast.size(1) && dy.size(2) == rast.size(2) && dy.size(3) > 0, "dy must have shape [>0, height, width, >0]"); + NVDR_CHECK(dy.size(3) == attr.size(attr.sizes().size() - 1), "argument count mismatch between inputs dy, attr"); + NVDR_CHECK((attr_depth == rast.size(0) || attr_depth == 1) && dy.size(0) == rast.size(0), "minibatch size mismatch between inputs rast, dy, attr"); + if (enable_da) + { + NVDR_CHECK(dda.sizes().size() == 4 && dda.size(0) > 0 && dda.size(1) == rast.size(1) && dda.size(2) == rast.size(2), "dda must have shape [>0, height, width, ?]"); + NVDR_CHECK(dda.size(0) == rast.size(0), "minibatch size mismatch between rast, dda"); + NVDR_CHECK(rast_db.sizes().size() == 4 && rast_db.size(0) > 0 && rast_db.size(1) > 0 && rast_db.size(2) > 0 && rast_db.size(3) == 4, "rast_db must have shape[>0, >0, >0, 4]"); + NVDR_CHECK(rast_db.size(1) == rast.size(1) && rast_db.size(2) == rast.size(2), "spatial size mismatch between inputs rast and rast_db"); + NVDR_CHECK(rast_db.size(0) == rast.size(0), "minibatch size mismatch between inputs rast, rast_db"); + } + + // Extract input dimensions. + p.numVertices = attr.size(p.instance_mode ? 1 : 0); + p.numAttr = attr.size(p.instance_mode ? 2 : 1); + p.numTriangles = tri.size(0); + p.height = rast.size(1); + p.width = rast.size(2); + p.depth = rast.size(0); + + // Ensure gradients are contiguous. + torch::Tensor dy_ = dy.contiguous(); + torch::Tensor dda_; + if (enable_da) + dda_ = dda.contiguous(); + + // Set attribute pixel differential info if enabled, otherwise leave as zero. + if (enable_da) + set_diff_attrs(p, diff_attrs_all, diff_attrs_vec); + else + p.numDiffAttr = 0; + + // Get input pointers. + p.attr = attr.data_ptr(); + p.rast = rast.data_ptr(); + p.tri = tri.data_ptr(); + p.dy = dy_.data_ptr(); + p.rastDB = enable_da ? rast_db.data_ptr() : NULL; + p.dda = enable_da ? dda_.data_ptr() : NULL; + p.attrBC = (p.instance_mode && attr_depth < p.depth) ? 1 : 0; + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor gradAttr = torch::zeros_like(attr); + torch::Tensor gradRaster = torch::empty_like(rast); + torch::Tensor gradRasterDB; + if (enable_da) + gradRasterDB = torch::empty_like(rast_db); + + p.gradAttr = gradAttr.data_ptr(); + p.gradRaster = gradRaster.data_ptr(); + p.gradRasterDB = enable_da ? gradRasterDB.data_ptr() : NULL; + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.rast & 15), "rast input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.rastDB & 15), "rast_db input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.dda & 7), "dda input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradRaster & 15), "grad_rast output tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.gradRasterDB & 15), "grad_rast_db output tensor not aligned to float4"); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH, IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = enable_da ? (void*)InterpolateGradKernelDa : (void*)InterpolateGradKernel; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + + // Return results. + return std::tuple(gradAttr, gradRaster, gradRasterDB); +} + +// Version without derivatives. +std::tuple interpolate_grad(torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy) +{ + std::vector empty_vec; + torch::Tensor empty_tensor; + std::tuple result = interpolate_grad_da(attr, rast, tri, dy, empty_tensor, empty_tensor, false, empty_vec); + return std::tuple(std::get<0>(result), std::get<1>(result)); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f2ef697aa9bd409c82674bd2c1d921140a691d7 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize.cpp @@ -0,0 +1,265 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/rasterize.h" +#include "../common/cudaraster/CudaRaster.hpp" +#include "../common/cudaraster/impl/Constants.hpp" +#include + +//------------------------------------------------------------------------ +// Kernel prototypes. + +void RasterizeCudaFwdShaderKernel(const RasterizeCudaFwdShaderParams p); +void RasterizeGradKernel(const RasterizeGradParams p); +void RasterizeGradKernelDb(const RasterizeGradParams p); + +//------------------------------------------------------------------------ +// Python CudaRaster state wrapper methods. + +RasterizeCRStateWrapper::RasterizeCRStateWrapper(int cudaDeviceIdx_) +{ + const at::cuda::OptionalCUDAGuard device_guard(cudaDeviceIdx_); + cudaDeviceIdx = cudaDeviceIdx_; + cr = new CR::CudaRaster(); +} + +RasterizeCRStateWrapper::~RasterizeCRStateWrapper(void) +{ + const at::cuda::OptionalCUDAGuard device_guard(cudaDeviceIdx); + delete cr; +} + +//------------------------------------------------------------------------ +// Forward op (Cuda). + +std::tuple rasterize_fwd_cuda(RasterizeCRStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple resolution, torch::Tensor ranges, int peeling_idx) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(pos)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + CR::CudaRaster* cr = stateWrapper.cr; + + // Check inputs. + NVDR_CHECK_DEVICE(pos, tri); + NVDR_CHECK_CPU(ranges); + NVDR_CHECK_CONTIGUOUS(pos, tri, ranges); + NVDR_CHECK_F32(pos); + NVDR_CHECK_I32(tri, ranges); + + // Check that CudaRaster context was created for the correct GPU. + NVDR_CHECK(pos.get_device() == stateWrapper.cudaDeviceIdx, "CudaRaster context must must reside on the same device as input tensors"); + + // Determine instance mode and check input dimensions. + bool instance_mode = pos.sizes().size() > 2; + if (instance_mode) + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "instance mode - pos must have shape [>0, >0, 4]"); + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "range mode - pos must have shape [>0, 4]"); + NVDR_CHECK(ranges.sizes().size() == 2 && ranges.size(0) > 0 && ranges.size(1) == 2, "range mode - ranges must have shape [>0, 2]"); + } + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + + // Get output shape. + int height_out = std::get<0>(resolution); + int width_out = std::get<1>(resolution); + int depth = instance_mode ? pos.size(0) : ranges.size(0); // Depth of tensor, not related to depth buffering. + NVDR_CHECK(height_out > 0 && width_out > 0, "resolution must be [>0, >0]"); + + // Round internal resolution up to tile size. + int height = (height_out + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int width = (width_out + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + + // Get position and triangle buffer sizes in vertices / triangles. + int posCount = instance_mode ? pos.size(1) : pos.size(0); + int triCount = tri.size(0); + + // Set up CudaRaster buffers. + const float* posPtr = pos.data_ptr(); + const int32_t* rangesPtr = instance_mode ? 0 : ranges.data_ptr(); // This is in CPU memory. + const int32_t* triPtr = tri.data_ptr(); + cr->setVertexBuffer((void*)posPtr, posCount); + cr->setIndexBuffer((void*)triPtr, triCount); + cr->setBufferSize(width_out, height_out, depth); + + // Enable depth peeling? + bool enablePeel = (peeling_idx > 0); + cr->setRenderModeFlags(enablePeel ? CR::CudaRaster::RenderModeFlag_EnableDepthPeeling : 0); // No backface culling. + if (enablePeel) + cr->swapDepthAndPeel(); // Use previous depth buffer as peeling depth input. + + // Determine viewport tiling. + int tileCountX = (width + CR_MAXVIEWPORT_SIZE - 1) / CR_MAXVIEWPORT_SIZE; + int tileCountY = (height + CR_MAXVIEWPORT_SIZE - 1) / CR_MAXVIEWPORT_SIZE; + int tileSizeX = ((width + tileCountX - 1) / tileCountX + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + int tileSizeY = ((height + tileCountY - 1) / tileCountY + CR_TILE_SIZE - 1) & (-CR_TILE_SIZE); + TORCH_CHECK(tileCountX > 0 && tileCountY > 0 && tileSizeX > 0 && tileSizeY > 0, "internal error in tile size calculation: count or size is zero"); + TORCH_CHECK(tileSizeX <= CR_MAXVIEWPORT_SIZE && tileSizeY <= CR_MAXVIEWPORT_SIZE, "internal error in tile size calculation: tile larger than allowed"); + TORCH_CHECK((tileSizeX & (CR_TILE_SIZE - 1)) == 0 && (tileSizeY & (CR_TILE_SIZE - 1)) == 0, "internal error in tile size calculation: tile not divisible by ", CR_TILE_SIZE); + TORCH_CHECK(tileCountX * tileSizeX >= width && tileCountY * tileSizeY >= height, "internal error in tile size calculation: tiles do not cover viewport"); + + // Rasterize in tiles. + for (int tileY = 0; tileY < tileCountY; tileY++) + for (int tileX = 0; tileX < tileCountX; tileX++) + { + // Set CudaRaster viewport according to tile. + int offsetX = tileX * tileSizeX; + int offsetY = tileY * tileSizeY; + int sizeX = (width_out - offsetX) < tileSizeX ? (width_out - offsetX) : tileSizeX; + int sizeY = (height_out - offsetY) < tileSizeY ? (height_out - offsetY) : tileSizeY; + cr->setViewport(sizeX, sizeY, offsetX, offsetY); + + // Run all triangles in one batch. In case of error, the workload could be split into smaller batches - maybe do that in the future. + // Only enable peeling-specific optimizations to skip first stages when image fits in one tile. Those are not valid otherwise. + cr->deferredClear(0u); + bool success = cr->drawTriangles(rangesPtr, enablePeel && (tileCountX == 1 && tileCountY == 1), stream); + NVDR_CHECK(success, "subtriangle count overflow"); + } + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({depth, height_out, width_out, 4}, opts); + torch::Tensor out_db = torch::empty({depth, height_out, width_out, 4}, opts); + + // Populate pixel shader kernel parameters. + RasterizeCudaFwdShaderParams p; + p.pos = posPtr; + p.tri = triPtr; + p.in_idx = (const int*)cr->getColorBuffer(); + p.out = out.data_ptr(); + p.out_db = out_db.data_ptr(); + p.numTriangles = triCount; + p.numVertices = posCount; + p.width_in = width; + p.height_in = height; + p.width_out = width_out; + p.height_out = height_out; + p.depth = depth; + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + p.xs = 2.f / (float)width_out; + p.xo = 1.f / (float)width_out - 1.f; + p.ys = 2.f / (float)height_out; + p.yo = 1.f / (float)height_out - 1.f; + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.out & 15), "out output tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.out_db & 15), "out_db output tensor not aligned to float4"); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_WIDTH, RAST_CUDA_FWD_SHADER_KERNEL_BLOCK_HEIGHT, p.width_out, p.height_out); + dim3 gridSize = getLaunchGridSize(blockSize, p.width_out, p.height_out, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((void*)RasterizeCudaFwdShaderKernel, gridSize, blockSize, args, 0, stream)); + + // Return. + return std::tuple(out, out_db); +} + +//------------------------------------------------------------------------ +// Gradient op. + +torch::Tensor rasterize_grad_db(torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy, torch::Tensor ddb) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(pos)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + RasterizeGradParams p; + bool enable_db = ddb.defined(); + + // Check inputs. + if (enable_db) + { + NVDR_CHECK_DEVICE(pos, tri, out, dy, ddb); + NVDR_CHECK_CONTIGUOUS(pos, tri, out); + NVDR_CHECK_F32(pos, out, dy, ddb); + NVDR_CHECK_I32(tri); + } + else + { + NVDR_CHECK_DEVICE(pos, tri, out, dy); + NVDR_CHECK_CONTIGUOUS(pos, tri, out); + NVDR_CHECK_F32(pos, out, dy); + NVDR_CHECK_I32(tri); + } + + // Determine instance mode. + p.instance_mode = (pos.sizes().size() > 2) ? 1 : 0; + + // Shape is taken from the rasterizer output tensor. + NVDR_CHECK(out.sizes().size() == 4, "tensor out must be rank-4"); + p.depth = out.size(0); + p.height = out.size(1); + p.width = out.size(2); + NVDR_CHECK(p.depth > 0 && p.height > 0 && p.width > 0, "resolution must be [>0, >0, >0]"); + + // Check other shapes. + if (p.instance_mode) + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) == p.depth && pos.size(1) > 0 && pos.size(2) == 4, "pos must have shape [depth, >0, 4]"); + else + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "pos must have shape [>0, 4]"); + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + NVDR_CHECK(out.sizes().size() == 4 && out.size(0) == p.depth && out.size(1) == p.height && out.size(2) == p.width && out.size(3) == 4, "out must have shape [depth, height, width, 4]"); + NVDR_CHECK( dy.sizes().size() == 4 && dy.size(0) == p.depth && dy.size(1) == p.height && dy.size(2) == p.width && dy.size(3) == 4, "dy must have shape [depth, height, width, 4]"); + if (enable_db) + NVDR_CHECK(ddb.sizes().size() == 4 && ddb.size(0) == p.depth && ddb.size(1) == p.height && ddb.size(2) == p.width && ddb.size(3) == 4, "ddb must have shape [depth, height, width, 4]"); + + // Ensure gradients are contiguous. + torch::Tensor dy_ = dy.contiguous(); + torch::Tensor ddb_; + if (enable_db) + ddb_ = ddb.contiguous(); + + // Populate parameters. + p.numTriangles = tri.size(0); + p.numVertices = p.instance_mode ? pos.size(1) : pos.size(0); + p.pos = pos.data_ptr(); + p.tri = tri.data_ptr(); + p.out = out.data_ptr(); + p.dy = dy_.data_ptr(); + p.ddb = enable_db ? ddb_.data_ptr() : NULL; + + // Set up pixel position to clip space x, y transform. + p.xs = 2.f / (float)p.width; + p.xo = 1.f / (float)p.width - 1.f; + p.ys = 2.f / (float)p.height; + p.yo = 1.f / (float)p.height - 1.f; + + // Allocate output tensor for position gradients. + torch::Tensor grad = torch::zeros_like(pos); + p.grad = grad.data_ptr(); + + // Verify that buffers are aligned to allow float2/float4 operations. + NVDR_CHECK(!((uintptr_t)p.pos & 15), "pos input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.dy & 7), "dy input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.ddb & 15), "ddb input tensor not aligned to float4"); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(RAST_GRAD_MAX_KERNEL_BLOCK_WIDTH, RAST_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.width, p.height); + dim3 gridSize = getLaunchGridSize(blockSize, p.width, p.height, p.depth); + + // Launch CUDA kernel. + void* args[] = {&p}; + void* func = enable_db ? (void*)RasterizeGradKernelDb : (void*)RasterizeGradKernel; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func, gridSize, blockSize, args, 0, stream)); + + // Return the gradients. + return grad; +} + +// Version without derivatives. +torch::Tensor rasterize_grad(torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy) +{ + torch::Tensor empty_tensor; + return rasterize_grad_db(pos, tri, out, dy, empty_tensor); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize_gl.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize_gl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76f875d6adaa273107ec7e44f7724cf59e90de48 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_rasterize_gl.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/rasterize_gl.h" +#include + +//------------------------------------------------------------------------ +// Python GL state wrapper methods. + +RasterizeGLStateWrapper::RasterizeGLStateWrapper(bool enableDB, bool automatic_, int cudaDeviceIdx_) +{ + pState = new RasterizeGLState(); + automatic = automatic_; + cudaDeviceIdx = cudaDeviceIdx_; + memset(pState, 0, sizeof(RasterizeGLState)); + pState->enableDB = enableDB ? 1 : 0; + rasterizeInitGLContext(NVDR_CTX_PARAMS, *pState, cudaDeviceIdx_); + releaseGLContext(); +} + +RasterizeGLStateWrapper::~RasterizeGLStateWrapper(void) +{ + setGLContext(pState->glctx); + rasterizeReleaseBuffers(NVDR_CTX_PARAMS, *pState); + releaseGLContext(); + destroyGLContext(pState->glctx); + delete pState; +} + +void RasterizeGLStateWrapper::setContext(void) +{ + setGLContext(pState->glctx); +} + +void RasterizeGLStateWrapper::releaseContext(void) +{ + releaseGLContext(); +} + +//------------------------------------------------------------------------ +// Forward op (OpenGL). + +std::tuple rasterize_fwd_gl(RasterizeGLStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple resolution, torch::Tensor ranges, int peeling_idx) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(pos)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + RasterizeGLState& s = *stateWrapper.pState; + + // Check inputs. + NVDR_CHECK_DEVICE(pos, tri); + NVDR_CHECK_CPU(ranges); + NVDR_CHECK_CONTIGUOUS(pos, tri, ranges); + NVDR_CHECK_F32(pos); + NVDR_CHECK_I32(tri, ranges); + + // Check that GL context was created for the correct GPU. + NVDR_CHECK(pos.get_device() == stateWrapper.cudaDeviceIdx, "GL context must must reside on the same device as input tensors"); + + // Determine number of outputs + int num_outputs = s.enableDB ? 2 : 1; + + // Determine instance mode and check input dimensions. + bool instance_mode = pos.sizes().size() > 2; + if (instance_mode) + NVDR_CHECK(pos.sizes().size() == 3 && pos.size(0) > 0 && pos.size(1) > 0 && pos.size(2) == 4, "instance mode - pos must have shape [>0, >0, 4]"); + else + { + NVDR_CHECK(pos.sizes().size() == 2 && pos.size(0) > 0 && pos.size(1) == 4, "range mode - pos must have shape [>0, 4]"); + NVDR_CHECK(ranges.sizes().size() == 2 && ranges.size(0) > 0 && ranges.size(1) == 2, "range mode - ranges must have shape [>0, 2]"); + } + NVDR_CHECK(tri.sizes().size() == 2 && tri.size(0) > 0 && tri.size(1) == 3, "tri must have shape [>0, 3]"); + + // Get output shape. + int height = std::get<0>(resolution); + int width = std::get<1>(resolution); + int depth = instance_mode ? pos.size(0) : ranges.size(0); + NVDR_CHECK(height > 0 && width > 0, "resolution must be [>0, >0]"); + + // Get position and triangle buffer sizes in int32/float32. + int posCount = 4 * pos.size(0) * (instance_mode ? pos.size(1) : 1); + int triCount = 3 * tri.size(0); + + // Set the GL context unless manual context. + if (stateWrapper.automatic) + setGLContext(s.glctx); + + // Resize all buffers. + bool changes = false; + rasterizeResizeBuffers(NVDR_CTX_PARAMS, s, changes, posCount, triCount, width, height, depth); + if (changes) + { +#ifdef _WIN32 + // Workaround for occasional blank first frame on Windows. + releaseGLContext(); + setGLContext(s.glctx); +#endif + } + + // Copy input data to GL and render. + const float* posPtr = pos.data_ptr(); + const int32_t* rangesPtr = instance_mode ? 0 : ranges.data_ptr(); // This is in CPU memory. + const int32_t* triPtr = tri.data_ptr(); + int vtxPerInstance = instance_mode ? pos.size(1) : 0; + rasterizeRender(NVDR_CTX_PARAMS, s, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, rangesPtr, width, height, depth, peeling_idx); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({depth, height, width, 4}, opts); + torch::Tensor out_db = torch::empty({depth, height, width, s.enableDB ? 4 : 0}, opts); + float* outputPtr[2]; + outputPtr[0] = out.data_ptr(); + outputPtr[1] = s.enableDB ? out_db.data_ptr() : NULL; + + // Copy rasterized results into CUDA buffers. + rasterizeCopyResults(NVDR_CTX_PARAMS, s, stream, outputPtr, width, height, depth); + + // Done. Release GL context and return. + if (stateWrapper.automatic) + releaseGLContext(); + + return std::tuple(out, out_db); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_texture.cpp b/extensions/nvdiffrast/nvdiffrast/torch/torch_texture.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e3693eedd9e17d862d74ca08ee9636dcca290cc --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_texture.cpp @@ -0,0 +1,718 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" +#include "torch_types.h" +#include "../common/common.h" +#include "../common/texture.h" +#include + +//------------------------------------------------------------------------ +// Kernel prototypes. + +void MipBuildKernel1 (const TextureKernelParams p); +void MipBuildKernel2 (const TextureKernelParams p); +void MipBuildKernel4 (const TextureKernelParams p); +void TextureFwdKernelNearest1 (const TextureKernelParams p); +void TextureFwdKernelNearest2 (const TextureKernelParams p); +void TextureFwdKernelNearest4 (const TextureKernelParams p); +void TextureFwdKernelLinear1 (const TextureKernelParams p); +void TextureFwdKernelLinear2 (const TextureKernelParams p); +void TextureFwdKernelLinear4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearest1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearest2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearest4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinear1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinear2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinear4 (const TextureKernelParams p); +void TextureFwdKernelCubeNearest1 (const TextureKernelParams p); +void TextureFwdKernelCubeNearest2 (const TextureKernelParams p); +void TextureFwdKernelCubeNearest4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinear1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinear2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinear4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearest1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearest2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearest4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinear1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinear2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinear4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearestBO1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearestBO2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapNearestBO4 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinearBO1 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinearBO2 (const TextureKernelParams p); +void TextureFwdKernelLinearMipmapLinearBO4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearestBO1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearestBO2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapNearestBO4 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinearBO1 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinearBO2 (const TextureKernelParams p); +void TextureFwdKernelCubeLinearMipmapLinearBO4 (const TextureKernelParams p); +void MipGradKernel1 (const TextureKernelParams p); +void MipGradKernel2 (const TextureKernelParams p); +void MipGradKernel4 (const TextureKernelParams p); +void TextureGradKernelNearest (const TextureKernelParams p); +void TextureGradKernelLinear (const TextureKernelParams p); +void TextureGradKernelLinearMipmapNearest (const TextureKernelParams p); +void TextureGradKernelLinearMipmapLinear (const TextureKernelParams p); +void TextureGradKernelCubeNearest (const TextureKernelParams p); +void TextureGradKernelCubeLinear (const TextureKernelParams p); +void TextureGradKernelCubeLinearMipmapNearest (const TextureKernelParams p); +void TextureGradKernelCubeLinearMipmapLinear (const TextureKernelParams p); +void TextureGradKernelLinearMipmapNearestBO (const TextureKernelParams p); +void TextureGradKernelLinearMipmapLinearBO (const TextureKernelParams p); +void TextureGradKernelCubeLinearMipmapNearestBO (const TextureKernelParams p); +void TextureGradKernelCubeLinearMipmapLinearBO (const TextureKernelParams p); + +//------------------------------------------------------------------------ +// Modeselektor. + +static void set_modes(TextureKernelParams& p, int filter_mode, int boundary_mode, int max_mip_level) +{ + // Mip and filter modes. + p.filterMode = filter_mode; + NVDR_CHECK(p.filterMode >= 0 && p.filterMode < TEX_MODE_COUNT, "filter_mode unsupported"); + p.enableMip = (p.filterMode == TEX_MODE_LINEAR_MIPMAP_NEAREST || p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR); + + // Mip level clamp. + if (p.enableMip) + { + p.mipLevelLimit = max_mip_level; + NVDR_CHECK(p.mipLevelLimit >= -1, "invalid max_mip_level"); + } + + // Boundary mode. + p.boundaryMode = boundary_mode; + NVDR_CHECK(p.boundaryMode >= 0 && p.boundaryMode < TEX_BOUNDARY_MODE_COUNT, "boundary_mode unsupported"); +} + +//------------------------------------------------------------------------ +// Mipmap construction. + +TextureMipWrapper texture_construct_mip(torch::Tensor tex, int max_mip_level, bool cube_mode) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tex)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TextureKernelParams p = {}; // Initialize all fields to zero. + p.mipLevelLimit = max_mip_level; + p.boundaryMode = cube_mode ? TEX_BOUNDARY_MODE_CUBE : TEX_BOUNDARY_MODE_WRAP; + NVDR_CHECK(p.mipLevelLimit >= -1, "invalid max_mip_level"); + + // Check inputs. + NVDR_CHECK_DEVICE(tex); + NVDR_CHECK_CONTIGUOUS(tex); + NVDR_CHECK_F32(tex); + + // Populate parameters and sanity check tex shape. + if (!cube_mode) + { + NVDR_CHECK(tex.sizes().size() == 4 && tex.size(0) > 0 && tex.size(1) > 0 && tex.size(2) > 0 && tex.size(3) > 0, "tex must have shape[>0, >0, >0, >0]"); + } + else + { + NVDR_CHECK(tex.sizes().size() == 5 && tex.size(0) > 0 && tex.size(1) == 6 && tex.size(2) > 0 && tex.size(3) > 0 && tex.size(4) > 0, "tex must have shape[>0, 6, >0, >0, >0] in cube map mode"); + NVDR_CHECK(tex.size(2) == tex.size(3), "texture shape must be square in cube map mode"); + } + p.texDepth = tex.size(0); + p.texHeight = tex.size(cube_mode ? 2 : 1); + p.texWidth = tex.size(cube_mode ? 3 : 2); + p.channels = tex.size(cube_mode ? 4 : 3); + + // Set texture pointer. + p.tex[0] = tex.data_ptr(); + + // Generate mip offsets and calculate total size. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets); + + // Allocate and set mip tensor. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor mip = torch::empty({mipTotal}, opts); + float* pmip = mip.data_ptr(); + for (int i=1; i <= p.mipLevelMax; i++) + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + + // Choose kernel variants based on channel count. + void* args[] = {&p}; + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Build mip levels. + for (int i=1; i <= p.mipLevelMax; i++) + { + int2 ms = mipLevelSize(p, i); + int3 sz = make_int3(ms.x, ms.y, p.texDepth); + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_MIP_KERNEL_BLOCK_HEIGHT, sz.x, sz.y); + dim3 gridSize = getLaunchGridSize(blockSize, sz.x, sz.y, sz.z * (cube_mode ? 6 : 1)); + p.mipLevelOut = i; + + void* build_func_tbl[3] = { (void*)MipBuildKernel1, (void*)MipBuildKernel2, (void*)MipBuildKernel4 }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(build_func_tbl[channel_div_idx], gridSize, blockSize, args, 0, stream)); + } + + // Return the mip tensor in a wrapper. + TextureMipWrapper mip_wrapper; + mip_wrapper.mip = mip; + mip_wrapper.max_mip_level = max_mip_level; + mip_wrapper.texture_size = tex.sizes().vec(); + mip_wrapper.cube_mode = cube_mode; + return mip_wrapper; +} + +//------------------------------------------------------------------------ +// Forward op. + +torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tex)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TextureKernelParams p = {}; // Initialize all fields to zero. + bool has_mip_stack = (mip_stack.size() > 0); + torch::Tensor& mip_w = mip_wrapper.mip; // Unwrap. + int max_mip_level = has_mip_stack ? mip_stack.size() : mip_wrapper.max_mip_level; + set_modes(p, filter_mode, boundary_mode, max_mip_level); + + // See if we have these tensors or not. + bool has_uv_da = uv_da.defined() && uv_da.nbytes(); + bool has_mip_level_bias = mip_level_bias.defined() && mip_level_bias.nbytes(); + + if (p.enableMip) + { + NVDR_CHECK(has_uv_da || has_mip_level_bias, "mipmapping filter mode requires uv_da and/or mip_level_bias input"); + NVDR_CHECK(has_mip_stack || mip_w.defined(), "mipmapping filter mode requires mip wrapper or mip stack input"); + } + + // Check inputs. + NVDR_CHECK_DEVICE(tex, uv); + NVDR_CHECK_CONTIGUOUS(tex, uv); + NVDR_CHECK_F32(tex, uv); + if (p.enableMip) + { + if (has_mip_stack) + { + TORCH_CHECK(at::cuda::check_device(mip_stack), __func__, "(): Mip stack inputs must reside on the correct GPU device"); + nvdr_check_contiguous(mip_stack, __func__, "(): Mip stack inputs must be contiguous tensors"); + nvdr_check_f32(mip_stack, __func__, "(): Mip stack inputs must be float32 tensors"); + } + else + { + NVDR_CHECK_DEVICE(mip_w); + NVDR_CHECK_CONTIGUOUS(mip_w); + NVDR_CHECK_F32(mip_w); + } + if (has_uv_da) + { + NVDR_CHECK_DEVICE(uv_da); + NVDR_CHECK_CONTIGUOUS(uv_da); + NVDR_CHECK_F32(uv_da); + } + if (has_mip_level_bias) + { + NVDR_CHECK_DEVICE(mip_level_bias); + NVDR_CHECK_CONTIGUOUS(mip_level_bias); + NVDR_CHECK_F32(mip_level_bias); + } + } + + // Sanity checks and state setters. + bool cube_mode = (boundary_mode == TEX_BOUNDARY_MODE_CUBE); + if (!cube_mode) + { + NVDR_CHECK(tex.sizes().size() == 4 && tex.size(0) > 0 && tex.size(1) > 0 && tex.size(2) > 0 && tex.size(3) > 0, "tex must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 2, "uv must have shape [>0, >0, >0, 2]"); + p.texHeight = tex.size(1); + p.texWidth = tex.size(2); + p.channels = tex.size(3); + } + else + { + NVDR_CHECK(tex.sizes().size() == 5 && tex.size(0) > 0 && tex.size(1) == 6 && tex.size(2) > 0 && tex.size(3) > 0 && tex.size(4) > 0, "tex must have shape[>0, 6, >0, >0, >0] in cube map mode"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 3, "uv must have shape [>0, >0, >0, 3] in cube map mode"); + NVDR_CHECK(tex.size(2) == tex.size(3), "texture shape must be square in cube map mode"); + p.texHeight = tex.size(2); + p.texWidth = tex.size(3); + p.channels = tex.size(4); + } + NVDR_CHECK(tex.size(0) == 1 || tex.size(0) == uv.size(0), "minibatch size mismatch between inputs tex, uv"); + NVDR_CHECK(p.texWidth <= (1 << TEX_MAX_MIP_LEVEL) && p.texHeight <= (1 << TEX_MAX_MIP_LEVEL), "texture size too large"); + p.n = uv.size(0); + p.imgHeight = uv.size(1); + p.imgWidth = uv.size(2); + p.texDepth = tex.size(0); + if (p.enableMip) + { + if (has_uv_da) + { + if (!cube_mode) + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 4, "uv_da must have shape [minibatch_size, height, width, 4]"); + else + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 6, "uv_da must have shape [minibatch_size, height, width, 6] in cube map mode"); + } + if (has_mip_level_bias) + NVDR_CHECK(mip_level_bias.sizes().size() == 3 && mip_level_bias.size(0) == p.n && mip_level_bias.size(1) == p.imgHeight && mip_level_bias.size(2) == p.imgWidth, "mip_level_bias must have shape [minibatch_size, height, width]"); + } + + // Get input pointers. + p.tex[0] = tex.data_ptr(); + p.uv = uv.data_ptr(); + p.uvDA = (p.enableMip && has_uv_da) ? uv_da.data_ptr() : NULL; + p.mipLevelBias = (p.enableMip && has_mip_level_bias) ? mip_level_bias.data_ptr() : NULL; + + // Allocate output tensor. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({p.n, p.imgHeight, p.imgWidth, p.channels}, opts); + p.out = out.data_ptr(); + + // Choose kernel variants based on channel count. + void* args[] = {&p}; + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + float* pmip = 0; + if (p.enableMip) + { + if (has_mip_stack) + { + // Custom mip stack supplied. Check that sizes match and assign. + p.mipLevelMax = max_mip_level; + for (int i=1; i <= p.mipLevelMax; i++) + { + torch::Tensor& t = mip_stack[i-1]; + int2 sz = mipLevelSize(p, i); + if (!cube_mode) + NVDR_CHECK(t.sizes().size() == 4 && t.size(0) == tex.size(0) && t.size(1) == sz.y && t.size(2) == sz.x && t.size(3) == p.channels, "mip level size mismatch in custom mip stack"); + else + NVDR_CHECK(t.sizes().size() == 5 && t.size(0) == tex.size(0) && t.size(1) == 6 && t.size(2) == sz.y && t.size(3) == sz.x && t.size(4) == p.channels, "mip level size mismatch in mip stack"); + if (sz.x == 1 && sz.y == 1) + NVDR_CHECK(i == p.mipLevelMax, "mip level size mismatch in mip stack"); + p.tex[i] = t.data_ptr(); + } + } + else + { + // Generate mip offsets, check mipmap size, and set mip data pointer. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets); + NVDR_CHECK(tex.sizes() == mip_wrapper.texture_size && cube_mode == mip_wrapper.cube_mode, "mip does not match texture size"); + NVDR_CHECK(mip_w.sizes().size() == 1 && mip_w.size(0) == mipTotal, "wrapped mip tensor size mismatch"); + pmip = mip_w.data_ptr(); + for (int i=1; i <= p.mipLevelMax; i++) + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + } + } + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. + if (!cube_mode) + NVDR_CHECK(!((uintptr_t)p.uv & 7), "uv input tensor not aligned to float2"); + if ((p.channels & 3) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + NVDR_CHECK(!((uintptr_t)p.tex[i] & 15), "tex or mip input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.out & 15), "out output tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)pmip & 15), "mip input tensor not aligned to float4"); + } + if ((p.channels & 1) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + NVDR_CHECK(!((uintptr_t)p.tex[i] & 7), "tex or mip input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.out & 7), "out output tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)pmip & 7), "mip input tensor not aligned to float2"); + } + if (!cube_mode) + NVDR_CHECK(!((uintptr_t)p.uvDA & 15), "uv_da input tensor not aligned to float4"); + else + NVDR_CHECK(!((uintptr_t)p.uvDA & 7), "uv_da input tensor not aligned to float2"); + + // Choose launch parameters for texture lookup kernel. + dim3 blockSize = getLaunchBlockSize(TEX_FWD_MAX_KERNEL_BLOCK_WIDTH, TEX_FWD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + // Choose kernel based on filter mode, cube mode, bias-only mode, and datatype. + void* func_tbl[TEX_MODE_COUNT * 2 * 2 * 3] = { + (void*)TextureFwdKernelNearest1, + (void*)TextureFwdKernelNearest2, + (void*)TextureFwdKernelNearest4, + (void*)TextureFwdKernelLinear1, + (void*)TextureFwdKernelLinear2, + (void*)TextureFwdKernelLinear4, + (void*)TextureFwdKernelLinearMipmapNearest1, + (void*)TextureFwdKernelLinearMipmapNearest2, + (void*)TextureFwdKernelLinearMipmapNearest4, + (void*)TextureFwdKernelLinearMipmapLinear1, + (void*)TextureFwdKernelLinearMipmapLinear2, + (void*)TextureFwdKernelLinearMipmapLinear4, + (void*)TextureFwdKernelCubeNearest1, + (void*)TextureFwdKernelCubeNearest2, + (void*)TextureFwdKernelCubeNearest4, + (void*)TextureFwdKernelCubeLinear1, + (void*)TextureFwdKernelCubeLinear2, + (void*)TextureFwdKernelCubeLinear4, + (void*)TextureFwdKernelCubeLinearMipmapNearest1, + (void*)TextureFwdKernelCubeLinearMipmapNearest2, + (void*)TextureFwdKernelCubeLinearMipmapNearest4, + (void*)TextureFwdKernelCubeLinearMipmapLinear1, + (void*)TextureFwdKernelCubeLinearMipmapLinear2, + (void*)TextureFwdKernelCubeLinearMipmapLinear4, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + (void*)TextureFwdKernelLinearMipmapNearestBO1, + (void*)TextureFwdKernelLinearMipmapNearestBO2, + (void*)TextureFwdKernelLinearMipmapNearestBO4, + (void*)TextureFwdKernelLinearMipmapLinearBO1, + (void*)TextureFwdKernelLinearMipmapLinearBO2, + (void*)TextureFwdKernelLinearMipmapLinearBO4, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + (void*)TextureFwdKernelCubeLinearMipmapNearestBO1, + (void*)TextureFwdKernelCubeLinearMipmapNearestBO2, + (void*)TextureFwdKernelCubeLinearMipmapNearestBO4, + (void*)TextureFwdKernelCubeLinearMipmapLinearBO1, + (void*)TextureFwdKernelCubeLinearMipmapLinearBO2, + (void*)TextureFwdKernelCubeLinearMipmapLinearBO4, + }; + + // Function index. + int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; // Cube variant. + if (p.enableMip && !has_uv_da) + func_idx += TEX_MODE_COUNT * 2; // Bias-only variant. + func_idx = func_idx * 3 + channel_div_idx; // Choose vector size. + + // Launch kernel. + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + + // Return output tensor. + return out; +} + +// Version without mipmaps. +torch::Tensor texture_fwd(torch::Tensor tex, torch::Tensor uv, int filter_mode, int boundary_mode) +{ + torch::Tensor empty_tensor; + std::vector empty_vector; + return texture_fwd_mip(tex, uv, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode); +} + +//------------------------------------------------------------------------ +// Gradient op. + +std::tuple > texture_grad_linear_mipmap_linear(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(tex)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TextureKernelParams p = {}; // Initialize all fields to zero. + bool has_mip_stack = (mip_stack.size() > 0); + torch::Tensor& mip_w = mip_wrapper.mip; // Unwrap. + int max_mip_level = has_mip_stack ? mip_stack.size() : mip_wrapper.max_mip_level; + set_modes(p, filter_mode, boundary_mode, max_mip_level); + + // See if we have these tensors or not. + bool has_uv_da = uv_da.defined() && uv_da.nbytes(); + bool has_mip_level_bias = mip_level_bias.defined() && mip_level_bias.nbytes(); + + if (p.enableMip) + { + NVDR_CHECK(has_uv_da || has_mip_level_bias, "mipmapping filter mode requires uv_da and/or mip_level_bias input"); + NVDR_CHECK(has_mip_stack || mip_w.defined(), "mipmapping filter mode requires mip wrapper or mip stack input"); + } + + // Check inputs. + NVDR_CHECK_DEVICE(tex, uv); + NVDR_CHECK_CONTIGUOUS(tex, uv); + NVDR_CHECK_F32(tex, uv); + if (p.enableMip) + { + if (has_mip_stack) + { + TORCH_CHECK(at::cuda::check_device(mip_stack), __func__, "(): Mip stack inputs must reside on the correct GPU device"); + nvdr_check_contiguous(mip_stack, __func__, "(): Mip stack inputs must be contiguous tensors"); + nvdr_check_f32(mip_stack, __func__, "(): Mip stack inputs must be float32 tensors"); + } + else + { + NVDR_CHECK_DEVICE(mip_w); + NVDR_CHECK_CONTIGUOUS(mip_w); + NVDR_CHECK_F32(mip_w); + } + if (has_uv_da) + { + NVDR_CHECK_DEVICE(uv_da); + NVDR_CHECK_CONTIGUOUS(uv_da); + NVDR_CHECK_F32(uv_da); + } + if (has_mip_level_bias) + { + NVDR_CHECK_DEVICE(mip_level_bias); + NVDR_CHECK_CONTIGUOUS(mip_level_bias); + NVDR_CHECK_F32(mip_level_bias); + } + } + + // Sanity checks and state setters. + bool cube_mode = (boundary_mode == TEX_BOUNDARY_MODE_CUBE); + if (!cube_mode) + { + NVDR_CHECK(tex.sizes().size() == 4 && tex.size(0) > 0 && tex.size(1) > 0 && tex.size(2) > 0 && tex.size(3) > 0, "tex must have shape[>0, >0, >0, >0]"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 2, "uv must have shape [>0, >0, >0, 2]"); + p.texHeight = tex.size(1); + p.texWidth = tex.size(2); + p.channels = tex.size(3); + } + else + { + NVDR_CHECK(tex.sizes().size() == 5 && tex.size(0) > 0 && tex.size(1) == 6 && tex.size(2) > 0 && tex.size(3) > 0 && tex.size(4) > 0, "tex must have shape[>0, 6, >0, >0, >0] in cube map mode"); + NVDR_CHECK(uv.sizes().size() == 4 && uv.size(0) > 0 && uv.size(1) > 0 && uv.size(2) > 0 && uv.size(3) == 3, "uv must have shape [>0, >0, >0, 3] in cube map mode"); + NVDR_CHECK(tex.size(2) == tex.size(3), "texture shape must be square in cube map mode"); + p.texHeight = tex.size(2); + p.texWidth = tex.size(3); + p.channels = tex.size(4); + } + NVDR_CHECK(tex.size(0) == 1 || tex.size(0) == uv.size(0), "minibatch size mismatch between inputs tex, uv"); + NVDR_CHECK(p.texWidth <= (1 << TEX_MAX_MIP_LEVEL) && p.texHeight <= (1 << TEX_MAX_MIP_LEVEL), "texture size too large"); + p.n = uv.size(0); + p.imgHeight = uv.size(1); + p.imgWidth = uv.size(2); + p.texDepth = tex.size(0); + if (p.enableMip) + { + if (has_uv_da) + { + if (!cube_mode) + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 4, "uv_da must have shape [minibatch_size, height, width, 4]"); + else + NVDR_CHECK(uv_da.sizes().size() == 4 && uv_da.size(0) == p.n && uv_da.size(1) == p.imgHeight && uv_da.size(2) == p.imgWidth && uv_da.size(3) == 6, "uv_da must have shape [minibatch_size, height, width, 6] in cube map mode"); + } + if (has_mip_level_bias) + NVDR_CHECK(mip_level_bias.sizes().size() == 3 && mip_level_bias.size(0) == p.n && mip_level_bias.size(1) == p.imgHeight && mip_level_bias.size(2) == p.imgWidth, "mip_level_bias must have shape [minibatch_size, height, width]"); + } + NVDR_CHECK(dy.sizes().size() == 4 && dy.size(0) == p.n && dy.size(1) == p.imgHeight && dy.size(2) == p.imgWidth && dy.size(3) == p.channels, "dy must have shape [minibatch_size, height, width, channels]"); + + // Get contiguous version of dy. + torch::Tensor dy_ = dy.contiguous(); + + // Get input pointers. + p.tex[0] = tex.data_ptr(); + p.uv = uv.data_ptr(); + p.dy = dy_.data_ptr(); + p.uvDA = (p.enableMip && has_uv_da) ? uv_da.data_ptr() : NULL; + p.mipLevelBias = (p.enableMip && has_mip_level_bias) ? mip_level_bias.data_ptr() : NULL; + + // Allocate output tensor for tex gradient. + torch::Tensor grad_tex = torch::zeros_like(tex); + p.gradTex[0] = grad_tex.data_ptr(); + + // Allocate output tensor for uv gradient. + torch::Tensor grad_uv; + torch::Tensor grad_uv_da; + torch::Tensor grad_mip_level_bias; + if (p.filterMode != TEX_MODE_NEAREST) + { + grad_uv = torch::empty_like(uv); + p.gradUV = grad_uv.data_ptr(); + + // Gradients for things affecting mip level. + if (p.filterMode == TEX_MODE_LINEAR_MIPMAP_LINEAR) + { + // Allocate output tensor for uv_da gradient. + if (has_uv_da) + { + grad_uv_da = torch::empty_like(uv_da); + p.gradUVDA = grad_uv_da.data_ptr(); + } + + // Allocate output tensor for mip_level_bias gradient. + if (has_mip_level_bias) + { + grad_mip_level_bias = torch::empty_like(mip_level_bias); + p.gradMipLevelBias = grad_mip_level_bias.data_ptr(); + } + } + } + + // Choose kernel variants based on channel count. + int channel_div_idx = 0; + if (!(p.channels & 3)) + channel_div_idx = 2; // Channel count divisible by 4. + else if (!(p.channels & 1)) + channel_div_idx = 1; // Channel count divisible by 2. + + // Mip-related setup. + torch::Tensor grad_mip; + std::vector grad_mip_stack; + float* pmip = 0; + float* pgradMip = 0; + if (p.enableMip) + { + if (has_mip_stack) + { + // Custom mip stack supplied. Check that sizes match, assign, construct gradient tensors. + p.mipLevelMax = max_mip_level; + for (int i=1; i <= p.mipLevelMax; i++) + { + torch::Tensor& t = mip_stack[i-1]; + int2 sz = mipLevelSize(p, i); + if (!cube_mode) + NVDR_CHECK(t.sizes().size() == 4 && t.size(0) == tex.size(0) && t.size(1) == sz.y && t.size(2) == sz.x && t.size(3) == p.channels, "mip level size mismatch in mip stack"); + else + NVDR_CHECK(t.sizes().size() == 5 && t.size(0) == tex.size(0) && t.size(1) == 6 && t.size(2) == sz.y && t.size(3) == sz.x && t.size(4) == p.channels, "mip level size mismatch in mip stack"); + if (sz.x == 1 && sz.y == 1) + NVDR_CHECK(i == p.mipLevelMax, "mip level size mismatch in mip stack"); + + torch::Tensor g = torch::zeros_like(t); + grad_mip_stack.push_back(g); + + p.tex[i] = t.data_ptr(); + p.gradTex[i] = g.data_ptr(); + } + } + else + { + // Generate mip offsets and get space for temporary mip gradients. + int mipOffsets[TEX_MAX_MIP_LEVEL]; + int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets); + NVDR_CHECK(tex.sizes() == mip_wrapper.texture_size && cube_mode == mip_wrapper.cube_mode, "mip does not match texture size"); + NVDR_CHECK(mip_w.sizes().size() == 1 && mip_w.size(0) == mipTotal, "mip tensor size mismatch"); + grad_mip = torch::zeros_like(mip_w); + pmip = (float*)mip_w.data_ptr(); + pgradMip = grad_mip.data_ptr(); + for (int i=1; i <= p.mipLevelMax; i++) + { + p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels. + p.gradTex[i] = pgradMip + mipOffsets[i]; // Pointers to mip gradients. + } + } + } + + // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned. + if (!cube_mode) + { + NVDR_CHECK(!((uintptr_t)p.uv & 7), "uv input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradUV & 7), "grad_uv output tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.uvDA & 15), "uv_da input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.gradUVDA & 15), "grad_uv_da output tensor not aligned to float4"); + } + else + { + NVDR_CHECK(!((uintptr_t)p.uvDA & 7), "uv_da input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradUVDA & 7), "grad_uv_da output tensor not aligned to float2"); + } + if ((p.channels & 3) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + { + NVDR_CHECK(!((uintptr_t)p.tex[i] & 15), "tex or mip input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)p.gradTex[i] & 15), "grad_tex output tensor not aligned to float4"); + } + NVDR_CHECK(!((uintptr_t)p.dy & 15), "dy input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)pmip & 15), "mip input tensor not aligned to float4"); + NVDR_CHECK(!((uintptr_t)pgradMip & 15), "internal mip gradient tensor not aligned to float4"); + } + if ((p.channels & 1) == 0) + { + for (int i=0; i <= p.mipLevelMax; i++) + { + NVDR_CHECK(!((uintptr_t)p.tex[i] & 7), "tex or mip input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)p.gradTex[i] & 7), "grad_tex output tensor not aligned to float2"); + } + NVDR_CHECK(!((uintptr_t)p.dy & 7), "dy output tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)pmip & 7), "mip input tensor not aligned to float2"); + NVDR_CHECK(!((uintptr_t)pgradMip & 7), "internal mip gradient tensor not aligned to float2"); + } + + // Choose launch parameters for main gradient kernel. + void* args[] = {&p}; + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n); + + void* func_tbl[TEX_MODE_COUNT * 2 * 2] = { + (void*)TextureGradKernelNearest, + (void*)TextureGradKernelLinear, + (void*)TextureGradKernelLinearMipmapNearest, + (void*)TextureGradKernelLinearMipmapLinear, + (void*)TextureGradKernelCubeNearest, + (void*)TextureGradKernelCubeLinear, + (void*)TextureGradKernelCubeLinearMipmapNearest, + (void*)TextureGradKernelCubeLinearMipmapLinear, + NULL, + NULL, + (void*)TextureGradKernelLinearMipmapNearestBO, + (void*)TextureGradKernelLinearMipmapLinearBO, + NULL, + NULL, + (void*)TextureGradKernelCubeLinearMipmapNearestBO, + (void*)TextureGradKernelCubeLinearMipmapLinearBO, + }; + + // Function index. + int func_idx = p.filterMode; + if (cube_mode) + func_idx += TEX_MODE_COUNT; // Cube variant. + if (p.enableMip && !has_uv_da) + func_idx += TEX_MODE_COUNT * 2; // Bias-only variant. + + // Launch main gradient kernel. + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream)); + + // Launch kernel to pull gradients from mip levels. Don't do this if mip stack was supplied - individual level gradients are already there. + if (p.enableMip && !has_mip_stack) + { + dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT, p.texWidth, p.texHeight); + dim3 gridSize = getLaunchGridSize(blockSize, p.texWidth, p.texHeight, p.texDepth * (cube_mode ? 6 : 1)); + int sharedBytes = blockSize.x * blockSize.y * p.channels * sizeof(float); + + void* mip_grad_func_tbl[3] = { (void*)MipGradKernel1, (void*)MipGradKernel2, (void*)MipGradKernel4 }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(mip_grad_func_tbl[channel_div_idx], gridSize, blockSize, args, sharedBytes, stream)); + } + + // Return output tensors. + return std::tuple >(grad_tex, grad_uv, grad_uv_da, grad_mip_level_bias, grad_mip_stack); +} + +// Version for nearest filter mode. +torch::Tensor texture_grad_nearest(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode) +{ + torch::Tensor empty_tensor; + std::vector empty_vector; + std::tuple > result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode); + return std::get<0>(result); +} + +// Version for linear filter mode. +std::tuple texture_grad_linear(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode) +{ + torch::Tensor empty_tensor; + std::vector empty_vector; + std::tuple > result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode); + return std::tuple(std::get<0>(result), std::get<1>(result)); +} + +// Version for linear-mipmap-nearest mode. +std::tuple > texture_grad_linear_mipmap_nearest(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector mip_stack, int filter_mode, int boundary_mode) +{ + std::tuple > result = texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode, boundary_mode); + return std::tuple >(std::get<0>(result), std::get<1>(result), std::get<4>(result)); +} + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/nvdiffrast/torch/torch_types.h b/extensions/nvdiffrast/nvdiffrast/torch/torch_types.h new file mode 100644 index 0000000000000000000000000000000000000000..9ea4bdea7c011866cd80b72d7f0a593ee091d070 --- /dev/null +++ b/extensions/nvdiffrast/nvdiffrast/torch/torch_types.h @@ -0,0 +1,65 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "torch_common.inl" + +//------------------------------------------------------------------------ +// Python GL state wrapper. + +class RasterizeGLState; +class RasterizeGLStateWrapper +{ +public: + RasterizeGLStateWrapper (bool enableDB, bool automatic, int cudaDeviceIdx); + ~RasterizeGLStateWrapper (void); + + void setContext (void); + void releaseContext (void); + + RasterizeGLState* pState; + bool automatic; + int cudaDeviceIdx; +}; + +//------------------------------------------------------------------------ +// Python CudaRaster state wrapper. + +namespace CR { class CudaRaster; } +class RasterizeCRStateWrapper +{ +public: + RasterizeCRStateWrapper (int cudaDeviceIdx); + ~RasterizeCRStateWrapper (void); + + CR::CudaRaster* cr; + int cudaDeviceIdx; +}; + +//------------------------------------------------------------------------ +// Mipmap wrapper to prevent intrusion from Python side. + +class TextureMipWrapper +{ +public: + torch::Tensor mip; + int max_mip_level; + std::vector texture_size; // For error checking. + bool cube_mode; // For error checking. +}; + + +//------------------------------------------------------------------------ +// Antialias topology hash wrapper to prevent intrusion from Python side. + +class TopologyHashWrapper +{ +public: + torch::Tensor ev_hash; +}; + +//------------------------------------------------------------------------ diff --git a/extensions/nvdiffrast/run_sample.sh b/extensions/nvdiffrast/run_sample.sh new file mode 100644 index 0000000000000000000000000000000000000000..08b4104d6e8aef18b0ff8b9491f74586df0cd550 --- /dev/null +++ b/extensions/nvdiffrast/run_sample.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +function print_help { + echo "Usage: `basename $0` [--build-container] " + echo "" + echo "Option --build-container will build the Docker container based on" + echo "docker/Dockerfile and tag the image with gltorch:latest." + echo "" + echo "Example: `basename $0` samples/torch/envphong.py" +} + +build_container=0 +sample="" +while [[ "$#" -gt 0 ]]; do + case $1 in + --build-container) build_container=1;; + -h|--help) print_help; exit 0 ;; + --*) echo "Unknown parameter passed: $1"; exit 1 ;; + *) sample="$1"; shift; break; + esac + shift +done + +rest=$@ + +# Build the docker container +if [ "$build_container" = "1" ]; then + docker build --tag gltorch:latest -f docker/Dockerfile . +fi + +if [ ! -f "$sample" ]; then + echo + echo "No python sample given or file '$sample' not found. Exiting." + exit 1 +fi + +image="gltorch:latest" + +echo "Using container image: $image" +echo "Running command: $sample $rest" + +# Run a sample with docker +docker run --rm -it --gpus all --user $(id -u):$(id -g) \ + -v `pwd`:/app --workdir /app -e TORCH_EXTENSIONS_DIR=/app/tmp $image python3 $sample $rest diff --git a/extensions/nvdiffrast/setup copy.py b/extensions/nvdiffrast/setup copy.py new file mode 100644 index 0000000000000000000000000000000000000000..07541da2a746eb9cd5f6ae61cb34475b890439cd --- /dev/null +++ b/extensions/nvdiffrast/setup copy.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import nvdiffrast +import setuptools +import os + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="nvdiffrast", + version=nvdiffrast.__version__, + author="Samuli Laine", + author_email="slaine@nvidia.com", + description="nvdiffrast - modular primitives for high-performance differentiable rendering", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/NVlabs/nvdiffrast", + packages=setuptools.find_packages(), + package_data={ + 'nvdiffrast': [ + 'common/*.h', + 'common/*.inl', + 'common/*.cu', + 'common/*.cpp', + 'common/cudaraster/*.hpp', + 'common/cudaraster/impl/*.cpp', + 'common/cudaraster/impl/*.hpp', + 'common/cudaraster/impl/*.inl', + 'common/cudaraster/impl/*.cu', + 'lib/*.h', + 'torch/*.h', + 'torch/*.inl', + 'torch/*.cpp', + 'tensorflow/*.cu', + ] + (['lib/*.lib'] if os.name == 'nt' else []) + }, + include_package_data=True, + install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container + classifiers=[ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + ], + python_requires='>=3.6', +) diff --git a/extensions/nvdiffrast/setup.py b/extensions/nvdiffrast/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b335971c46c41edee3352c882c6d23823d6b5ad1 --- /dev/null +++ b/extensions/nvdiffrast/setup.py @@ -0,0 +1,82 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import nvdiffrast +import setuptools +import os +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="nvdiffrast", + version=nvdiffrast.__version__, + author="Samuli Laine", + author_email="slaine@nvidia.com", + description="nvdiffrast - modular primitives for high-performance differentiable rendering", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/NVlabs/nvdiffrast", + packages=setuptools.find_packages(), + # package_data={ + # 'nvdiffrast': [ + # 'common/*.h', + # 'common/*.inl', + # 'common/*.cu', + # 'common/*.cpp', + # 'common/cudaraster/*.hpp', + # 'common/cudaraster/impl/*.cpp', + # 'common/cudaraster/impl/*.hpp', + # 'common/cudaraster/impl/*.inl', + # 'common/cudaraster/impl/*.cu', + # 'lib/*.h', + # 'torch/*.h', + # 'torch/*.inl', + # 'torch/*.cpp', + # 'tensorflow/*.cu', + # ] + (['lib/*.lib'] if os.name == 'nt' else []) + # }, + # include_package_data=True, + ext_modules=[ + CUDAExtension( + name="nvdiffrast.torch._C", + sources=[ + 'nvdiffrast/common/cudaraster/impl/Buffer.cpp', + 'nvdiffrast/common/cudaraster/impl/CudaRaster.cpp', + 'nvdiffrast/common/cudaraster/impl/RasterImpl_.cu', + 'nvdiffrast/common/cudaraster/impl/RasterImpl.cpp', + 'nvdiffrast/common/common.cpp', + 'nvdiffrast/common/rasterize.cu', + 'nvdiffrast/common/interpolate.cu', + 'nvdiffrast/common/texture_.cu', + 'nvdiffrast/common/texture.cpp', + 'nvdiffrast/common/antialias.cu', + 'nvdiffrast/torch/torch_bindings.cpp', + 'nvdiffrast/torch/torch_rasterize.cpp', + 'nvdiffrast/torch/torch_interpolate.cpp', + 'nvdiffrast/torch/torch_texture.cpp', + 'nvdiffrast/torch/torch_antialias.cpp', + ], + extra_compile_args={ + 'cxx': ['-DNVDR_TORCH'], + 'nvcc': ['-DNVDR_TORCH', '-lineinfo'], + }, + ) + ], + cmdclass={ + 'build_ext': BuildExtension + }, + install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container + classifiers=[ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + ], + python_requires='>=3.6', +) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..47a19c1b300c29c9702e16bcacb4b821441c43c1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +--extra-index-url https://download.pytorch.org/whl/cu121 + +torch==2.4.0 +torchvision==0.19.0 +pillow==10.4.0 +imageio==2.36.1 +imageio-ffmpeg==0.5.1 +tqdm==4.67.1 +easydict==1.13 +opencv-python-headless==4.10.0.84 +scipy==1.14.1 +rembg==2.0.60 +onnxruntime==1.20.1 +trimesh==4.5.3 +xatlas==0.0.9 +pyvista==0.44.2 +pymeshfix==0.17.0 +igraph==0.11.8 +git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 +xformers==0.0.27.post2 +spconv-cu120==2.3.6 +transformers==4.46.3 +gradio_litmodel3d==0.0.1 +https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +https://huggingface.co./spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true +https://huggingface.co./spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true \ No newline at end of file diff --git a/trellis/__init__.py b/trellis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b02ac31563fec7c36fa4bd5d420ac4af2472bba8 --- /dev/null +++ b/trellis/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . import utils diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00fd66a3272a48d43f9853682db48bcde2959d63 --- /dev/null +++ b/trellis/models/__init__.py @@ -0,0 +1,70 @@ +import importlib + +__attributes = { + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + 'SLatEncoder': 'structured_latent_vae', + 'SLatGaussianDecoder': 'structured_latent_vae', + 'SLatRadianceFieldDecoder': 'structured_latent_vae', + 'SLatMeshDecoder': 'structured_latent_vae', + 'SLatFlowModel': 'structured_latent_flow', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file)) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel + from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder + from .structured_latent_flow import SLatFlowModel diff --git a/trellis/models/sparse_structure_flow.py b/trellis/models/sparse_structure_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..baa5dd9644e569b73717d7e7a9ebed55e9930459 --- /dev/null +++ b/trellis/models/sparse_structure_flow.py @@ -0,0 +1,200 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ..modules.spatial import patchify, unpatchify + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + + self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = patchify(x, self.patch_size) + h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + h = h.type(self.dtype) + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + h = h.type(x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) + h = unpatchify(h, self.patch_size).contiguous() + + return h diff --git a/trellis/models/sparse_structure_vae.py b/trellis/models/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed49ae65b9cde2a45a59beb6868981a644b75d3 --- /dev/null +++ b/trellis/models/sparse_structure_vae.py @@ -0,0 +1,306 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. + """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..19c11597244ea53505d746b593d10cddad4bcb6f --- /dev/null +++ b/trellis/models/structured_latent_flow.py @@ -0,0 +1,262 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules.norm import LayerNorm32 +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + emb_channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(emb_channels, 2 * self.out_channels, bias=True), + ) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = None + if self.downsample: + self.updown = sp.SparseDownsample(2) + elif self.upsample: + self.updown = sp.SparseUpsample(2) + + def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.updown is not None: + x = self.updown(x) + return x + + def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: + emb_out = self.emb_layers(emb).type(x.dtype) + scale, shift = torch.chunk(emb_out, 2, dim=1) + + x = self._updown(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + + return h + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + num_io_res_blocks: int = 2, + io_block_channels: List[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + use_skip_connection: bool = True, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.num_io_res_blocks = num_io_res_blocks + self.io_block_channels = io_block_channels + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.use_skip_connection = use_skip_connection + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + + assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" + assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) + self.input_blocks = nn.ModuleList([]) + for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): + self.input_blocks.extend([ + SparseResBlock3d( + chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.input_blocks.append( + SparseResBlock3d( + chs, + model_channels, + out_channels=next_chs, + downsample=True, + ) + ) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_blocks = nn.ModuleList([]) + for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): + self.out_blocks.append( + SparseResBlock3d( + prev_chs * 2 if self.use_skip_connection else prev_chs, + model_channels, + out_channels=chs, + upsample=True, + ) + ) + self.out_blocks.extend([ + SparseResBlock3d( + chs * 2 if self.use_skip_connection else chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.blocks.apply(convert_module_to_f16) + self.out_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.blocks.apply(convert_module_to_f32) + self.out_blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: + h = self.input_layer(x).type(self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + cond = cond.type(self.dtype) + + skips = [] + # pack with input blocks + for block in self.input_blocks: + h = block(h, t_emb) + skips.append(h.feats) + + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + # unpack with output blocks + for block, skip in zip(self.out_blocks, reversed(skips)): + if self.use_skip_connection: + h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) + else: + h = block(h, t_emb) + + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(x.dtype)) + return h diff --git a/trellis/models/structured_latent_vae/__init__.py b/trellis/models/structured_latent_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00cbf8826b328f9abe76cd641645f67128d7a04b --- /dev/null +++ b/trellis/models/structured_latent_vae/__init__.py @@ -0,0 +1,4 @@ +from .encoder import SLatEncoder +from .decoder_gs import SLatGaussianDecoder +from .decoder_rf import SLatRadianceFieldDecoder +from .decoder_mesh import SLatMeshDecoder diff --git a/trellis/models/structured_latent_vae/base.py b/trellis/models/structured_latent_vae/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7b86006fb35dee6f4f61a6f827d13787e0a287b2 --- /dev/null +++ b/trellis/models/structured_latent_vae/base.py @@ -0,0 +1,117 @@ +from typing import * +import torch +import torch.nn as nn +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. + """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. + """ + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList([ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) + ]) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:]) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h diff --git a/trellis/models/structured_latent_vae/decoder_gs.py b/trellis/models/structured_latent_vae/decoder_gs.py new file mode 100644 index 0000000000000000000000000000000000000000..b6948173f57063ea1eff411f4840c8e1a711bd69 --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_gs.py @@ -0,0 +1,122 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from ...utils.random_utils import hammersley_sequence +from .base import SparseTransformerBase +from ...representations import Gaussian + + +class SLatGaussianDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + self._build_perturbation() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _build_perturbation(self) -> None: + perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] + perturbation = torch.tensor(perturbation).float() * 2 - 1 + perturbation = perturbation / self.rep_config['voxel_size'] + perturbation = torch.atanh(perturbation).to(self.device) + self.register_buffer('offset_perturbation', perturbation) + + def _calc_layout(self) -> None: + self.layout = { + '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4}, + '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Gaussian( + sh_degree=0, + aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], + mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], + scaling_bias = self.rep_config['scaling_bias'], + opacity_bias = self.rep_config['opacity_bias'], + scaling_activation = self.rep_config['scaling_activation'] + ) + xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + for k, v in self.layout.items(): + if k == '_xyz': + offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) + offset = offset * self.rep_config['lr'][k] + if self.rep_config['perturb_offset']: + offset = offset + self.offset_perturbation + offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] + _xyz = xyz.unsqueeze(1) + offset + setattr(representation, k, _xyz.flatten(0, 1)) + else: + feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) + feats = feats * self.rep_config['lr'][k] + setattr(representation, k, feats) + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Gaussian]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/decoder_mesh.py b/trellis/models/structured_latent_vae/decoder_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..06c0e7286af45eab1c9861b65174431fc2210bcc --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_mesh.py @@ -0,0 +1,167 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import MeshExtractResult +from ...representations.mesh import SparseFeatures2Mesh + + +class SparseSubdivideBlock3d(nn.Module): + """ + A 3D subdivide block that can subdivide the sparse tensor. + + Args: + channels: channels in the inputs and outputs. + out_channels: if specified, the number of output channels. + num_groups: the number of groups for the group norm. + """ + def __init__( + self, + channels: int, + resolution: int, + out_channels: Optional[int] = None, + num_groups: int = 32 + ): + super().__init__() + self.channels = channels + self.resolution = resolution + self.out_resolution = resolution * 2 + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x: an [N x C x ...] Tensor of features. + Returns: + an [N x C x ...] Tensor of outputs. + """ + h = self.act_layers(x) + h = self.sub(h) + x = self.sub(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + +class SLatMeshDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) + self.out_channels = self.mesh_extractor.feats_channels + self.upsample = nn.ModuleList([ + SparseSubdivideBlock3d( + channels=model_channels, + resolution=resolution, + out_channels=model_channels // 4 + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + resolution=resolution * 2, + out_channels=model_channels // 8 + ) + ]) + self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + mesh = self.mesh_extractor(x[i], training=self.training) + ret.append(mesh) + return ret + + def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + h = super().forward(x) + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/decoder_rf.py b/trellis/models/structured_latent_vae/decoder_rf.py new file mode 100644 index 0000000000000000000000000000000000000000..4e916eebafd7e97fed82cadb567244719dbbcd83 --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_rf.py @@ -0,0 +1,104 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import Strivec + + +class SLatRadianceFieldDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _calc_layout(self) -> None: + self.layout = { + 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']}, + 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']}, + 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Strivec( + sh_degree=0, + resolution=self.resolution, + aabb=[-0.5, -0.5, -0.5, 1, 1, 1], + rank=self.rep_config['rank'], + dim=self.rep_config['dim'], + device='cuda', + ) + representation.density_shift = 0.0 + representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') + for k, v in self.layout.items(): + setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])) + representation.trivec = representation.trivec + 1 + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Strivec]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/encoder.py b/trellis/models/structured_latent_vae/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c04928bbd0a3fe88c05c687024a92daa0a1d6d --- /dev/null +++ b/trellis/models/structured_latent_vae/encoder.py @@ -0,0 +1,72 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SLatEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z diff --git a/trellis/modules/attention/__init__.py b/trellis/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ffebf7dbf737606b32ef62c2e86f568189d322f0 --- /dev/null +++ b/trellis/modules/attention/__init__.py @@ -0,0 +1,36 @@ +from typing import * + +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_sttn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/trellis/modules/attention/full_attn.py b/trellis/modules/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..68303dca94cefacb43865d3b737f2723dded20dd --- /dev/null +++ b/trellis/modules/attention/full_attn.py @@ -0,0 +1,140 @@ +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == 'xformers': + import xformers.ops as xops +elif BACKEND == 'flash_attn': + import flash_attn +elif BACKEND == 'sdpa': + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == 'naive': + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == 'flash_attn': + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == 'sdpa': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/trellis/modules/attention/modules.py b/trellis/modules/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a82a9d27b6c02c7c88c724ddc0456fe61f6c0fb0 --- /dev/null +++ b/trellis/modules/attention/modules.py @@ -0,0 +1,146 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class RotaryPositionEmbedder(nn.Module): + def __init__(self, hidden_size: int, in_channels: int = 3): + super().__init__() + assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" + self.hidden_size = hidden_size + self.in_channels = in_channels + self.freq_dim = hidden_size // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (sp.SparseTensor): [..., N, D] tensor of queries + k (sp.SparseTensor): [..., N, D] tensor of keys + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + if indices is None: + indices = torch.arange(q.shape[-2], device=q.device) + if len(q.shape) > 2: + indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) + + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[1] < self.hidden_size // 2: + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), + torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) + )], dim=-1) + q_embed = self._rotary_embedding(q, phases) + k_embed = self._rotary_embedding(k, phases) + return q_embed, k_embed + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + if self.use_rope: + q, k, v = qkv.unbind(dim=2) + q, k = self.rope(q, k, indices) + qkv = torch.stack([q, k, v], dim=2) + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=2) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/trellis/modules/norm.py b/trellis/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8f1750b88aae9e5ea901fc7af9f978a0db82de6d --- /dev/null +++ b/trellis/modules/norm.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df649cc4431a0f7f1a49ed15780f4217399adf66 --- /dev/null +++ b/trellis/modules/sparse/__init__.py @@ -0,0 +1,102 @@ +from typing import * + +BACKEND = 'spconv' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get('SPARSE_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn is None: + env_sparse_attn = os.environ.get('ATTN_BACKEND') + + if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + ATTN = env_sparse_attn + + print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") + + +__from_env() + + +def set_backend(backend: Literal['spconv', 'torchsparse']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn(attn: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = attn + + +import importlib + +__attributes = { + 'SparseTensor': 'basic', + 'sparse_batch_broadcast': 'basic', + 'sparse_batch_op': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide' : 'spatial' +} + +__submodules = ['transformer'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + import transformer diff --git a/trellis/modules/sparse/attention/__init__.py b/trellis/modules/sparse/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..400de9a25960cf8ed32a3d7ec143af95b5f862bc --- /dev/null +++ b/trellis/modules/sparse/attention/__init__.py @@ -0,0 +1,4 @@ +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c724327678d38a5625699ace09c67107458b4d0a --- /dev/null +++ b/trellis/modules/sparse/attention/full_attn.py @@ -0,0 +1,215 @@ +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == 'flash_attn': + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis/modules/sparse/attention/modules.py b/trellis/modules/sparse/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..d8fbb572786483a840dc325097eacc08b815a0a5 --- /dev/null +++ b/trellis/modules/sparse/attention/modules.py @@ -0,0 +1,139 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. import SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from ...attention import RotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, SparseTensor): + x = x.replace(F.normalize(x.feats, dim=-1)) + else: + x = F.normalize(x, dim=-1) + return (x * self.gamma * self.scale).to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "serialized", "windowed"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + self.channels = channels + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_sequence = shift_sequence + self.shift_window = shift_window + self.serialize_mode = serialize_mode + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + @staticmethod + def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats + + def _rope(self, qkv: SparseTensor) -> SparseTensor: + q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k = self.rope(q, k, qkv.coords[:, 1:]) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + return qkv + + def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.use_rope: + qkv = self._rope(qkv) + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=1) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "serialized": + h = sparse_serialized_scaled_dot_product_self_attention( + qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window + ) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=1) + k = self.k_rms_norm(k) + kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3da276c4b47db2e2816d61bfa66db413aa6b7aa --- /dev/null +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,193 @@ +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_serialized_scaled_dot_product_self_attention', +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. + """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if 'vox2seq' not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start:s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] + split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] + bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + offset += valid_start - padded_start + bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..11eebf851316d8b4c1f6ca39b881c422d8f2f088 --- /dev/null +++ b/trellis/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,135 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0 +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (List[int]): Sequence lengths. + (List[int]): Sequence batch indices. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] + mask = seq_lens != 0 + seq_lens = seq_lens[mask].tolist() + seq_batch_indices = seq_batch_indices[mask].tolist() + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/trellis/modules/sparse/basic.py b/trellis/modules/sparse/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..0fc685128f47906e5608d38d789366bb06837513 --- /dev/null +++ b/trellis/modules/sparse/basic.py @@ -0,0 +1,459 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + global SparseTensorData + if SparseTensorData is None: + import importlib + if BACKEND == 'torchsparse': + SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif BACKEND == 'spconv': + SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == 'torchsparse': + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get('scale', (1, 1, 1)) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.F + elif BACKEND == 'spconv': + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.F = value + elif BACKEND == 'spconv': + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.C + elif BACKEND == 'spconv': + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.C = value + elif BACKEND == 'spconv': + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.dense() + elif BACKEND == 'spconv': + return self.data.dense() + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == 'torchsparse': + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == 'spconv': + new_data = SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> 'SparseTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + except: + pass + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/trellis/modules/sparse/conv/__init__.py b/trellis/modules/sparse/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdaddf7e07d42e296d056df28e56a544d2db5f2 --- /dev/null +++ b/trellis/modules/sparse/conv/__init__.py @@ -0,0 +1,21 @@ +from .. import BACKEND + + +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' + +def __from_env(): + import os + + global SPCONV_ALGO + env_spconv_algo = os.environ.get('SPCONV_ALGO') + if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: + SPCONV_ALGO = env_spconv_algo + print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") + + +__from_env() + +if BACKEND == 'torchsparse': + from .conv_torchsparse import * +elif BACKEND == 'spconv': + from .conv_spconv import * diff --git a/trellis/modules/sparse/conv/conv_spconv.py b/trellis/modules/sparse/conv/conv_spconv.py new file mode 100644 index 0000000000000000000000000000000000000000..856405dea4b24e5800cc056106bb34bb40f6eef0 --- /dev/null +++ b/trellis/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from .. import DEBUG +from . import SPCONV_ALGO + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + algo = None + if SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + if DEBUG: + assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/trellis/modules/sparse/conv/conv_torchsparse.py b/trellis/modules/sparse/conv/conv_torchsparse.py new file mode 100644 index 0000000000000000000000000000000000000000..a10bd9105581a96117abcbb7349ea5975e4304ba --- /dev/null +++ b/trellis/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +from .. import SparseTensor + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + + diff --git a/trellis/modules/sparse/linear.py b/trellis/modules/sparse/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..0c25a32a4ab08665a0d3f6bc44e61ef1c1cb2861 --- /dev/null +++ b/trellis/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) diff --git a/trellis/modules/sparse/nonlinearity.py b/trellis/modules/sparse/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..db81ee886f5d047a6bc98bb8f4d4ce867d2a302d --- /dev/null +++ b/trellis/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) + diff --git a/trellis/modules/sparse/norm.py b/trellis/modules/sparse/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..7c132d8389277bc9a60937633b58256784597c4d --- /dev/null +++ b/trellis/modules/sparse/norm.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from . import SparseTensor +from . import DEBUG + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/trellis/modules/sparse/spatial.py b/trellis/modules/sparse/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4713e82c42c6a1c7e72a00e3ab17e50f6f32a2 --- /dev/null +++ b/trellis/modules/sparse/spatial.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', + 'SparseSubdivide' +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): + super(SparseDownsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' + + coord = list(input.coords.unbind(dim=-1)) + for i, f in enumerate(factor): + coord[i+1] = coord[i+1] // f + + MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_feats = torch.scatter_reduce( + torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), + dim=0, + index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), + src=input.feats, + reduce='mean' + ) + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + out = SparseTensor(new_feats, new_coords, input.shape,) + out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + + out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) + out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) + out.register_spatial_cache(f'upsample_{factor}_idx', idx) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): + super(SparseUpsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' + + new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') + new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') + idx = input.get_spatial_cache(f'upsample_{factor}_idx') + if any([x is None for x in [new_coords, new_layout, idx]]): + raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') + new_feats = input.feats[idx] + out = SparseTensor(new_feats, new_coords, input.shape, new_layout) + out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + return out + +class SparseSubdivide(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self): + super(SparseSubdivide, self).__init__() + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + # upsample scale=2^DIM + n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) + n_coords = torch.nonzero(n_cube) + n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) + factor = n_coords.shape[0] + assert factor == 2 ** DIM + # print(n_coords.shape) + new_coords = input.coords.clone() + new_coords[:, 1:] *= 2 + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) + + new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) + out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) + out._scale = input._scale * 2 + out._spatial_cache = input._spatial_cache + return out + diff --git a/trellis/modules/sparse/transformer/__init__.py b/trellis/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67336cacd084ef5e779bf5a601d66720ea275fe6 --- /dev/null +++ b/trellis/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/sparse/transformer/blocks.py b/trellis/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..03b7283e3d2e6c8cdfd7fba82548a73ec7dd3130 --- /dev/null +++ b/trellis/modules/sparse/transformer/blocks.py @@ -0,0 +1,151 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) diff --git a/trellis/modules/sparse/transformer/modulated.py b/trellis/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec00bbb567dd432130490d799ac6cf480107593 --- /dev/null +++ b/trellis/modules/sparse/transformer/modulated.py @@ -0,0 +1,166 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/trellis/modules/spatial.py b/trellis/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3b750c1da9462818ad5e25cc50e59a7d92f786 --- /dev/null +++ b/trellis/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/trellis/modules/transformer/__init__.py b/trellis/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67336cacd084ef5e779bf5a601d66720ea275fe6 --- /dev/null +++ b/trellis/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/transformer/blocks.py b/trellis/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..ab65a2a20172b00e1f35e8e2db45701f393ff82d --- /dev/null +++ b/trellis/modules/transformer/blocks.py @@ -0,0 +1,182 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) + \ No newline at end of file diff --git a/trellis/modules/transformer/modulated.py b/trellis/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b90190b6dabc04b38499a6033483334fcfa69a --- /dev/null +++ b/trellis/modules/transformer/modulated.py @@ -0,0 +1,157 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) + \ No newline at end of file diff --git a/trellis/modules/utils.py b/trellis/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..215fe98059ab49eca42703a4f1f92d80c5343f6b --- /dev/null +++ b/trellis/modules/utils.py @@ -0,0 +1,54 @@ +import torch.nn as nn +from ..modules import sparse as sp + +FP16_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/trellis/pipelines/__init__.py b/trellis/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c01e15250c6b976fb61567e887c6f33bb849fe50 --- /dev/null +++ b/trellis/pipelines/__init__.py @@ -0,0 +1,24 @@ +from . import samplers +from .trellis_image_to_3d import TrellisImageTo3DPipeline + + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py new file mode 100644 index 0000000000000000000000000000000000000000..041ddce12a483557a25355bd6f2ccfe5bbfb5a17 --- /dev/null +++ b/trellis/pipelines/base.py @@ -0,0 +1,66 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + + @staticmethod + def from_pretrained(path: str) -> "Pipeline": + """ + Load a pretrained model. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = { + k: models.from_pretrained(f"{path}/{v}") + for k, v in args['models'].items() + } + + new_pipeline = Pipeline(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) diff --git a/trellis/pipelines/samplers/__init__.py b/trellis/pipelines/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1111715e7d39c8b5db1b70bbc1360d9bb6e0c6 --- /dev/null +++ b/trellis/pipelines/samplers/__init__.py @@ -0,0 +1,2 @@ +from .base import Sampler +from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler \ No newline at end of file diff --git a/trellis/pipelines/samplers/base.py b/trellis/pipelines/samplers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..bb70700117317477e738845e566b9ea87a768d0a --- /dev/null +++ b/trellis/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/trellis/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..076e1e3d9f6d1e7207c9659db530990894d614f9 --- /dev/null +++ b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,12 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py new file mode 100644 index 0000000000000000000000000000000000000000..a84c9d472b0c74b807663083b8aea63ff1eb3c7e --- /dev/null +++ b/trellis/pipelines/samplers/flow_euler.py @@ -0,0 +1,199 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. + """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + return pred_x_0, pred_eps, pred_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) diff --git a/trellis/pipelines/samplers/guidance_interval_mixin.py b/trellis/pipelines/samplers/guidance_interval_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..10524ce011de121db28b17fcbc2589e60019042e --- /dev/null +++ b/trellis/pipelines/samplers/guidance_interval_mixin.py @@ -0,0 +1,15 @@ +from typing import * + + +class GuidanceIntervalSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance with interval. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + if cfg_interval[0] <= t <= cfg_interval[1]: + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred + else: + return super()._inference_model(model, x_t, t, cond, **kwargs) diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..faeb32297f62e82cdbed6692a2d2b30da1c9c56d --- /dev/null +++ b/trellis/pipelines/trellis_image_to_3d.py @@ -0,0 +1,376 @@ +from typing import * +from contextlib import contextmanager +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from torchvision import transforms +from PIL import Image +import rembg +from .base import Pipeline +from . import samplers +from ..modules import sparse as sp +from ..representations import Gaussian, Strivec, MeshExtractResult + + +class TrellisImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Trellis image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + slat_sampler (samplers.Sampler): The sampler for the structured latent. + slat_normalization (dict): The normalization parameters for the structured latent. + image_cond_model (str): The name of the image conditioning model. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + slat_sampler: samplers.Sampler = None, + slat_normalization: dict = None, + image_cond_model: str = None, + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.slat_sampler = slat_sampler + self.sparse_structure_sampler_params = {} + self.slat_sampler_params = {} + self.slat_normalization = slat_normalization + self.rembg_session = None + self._init_image_cond_model(image_cond_model) + + @staticmethod + def from_pretrained(path: str) -> "TrellisImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path) + new_pipeline = TrellisImageTo3DPipeline() + new_pipeline.__dict__ = pipeline.__dict__ + args = pipeline._pretrained_args + + new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args']) + new_pipeline.slat_sampler_params = args['slat_sampler']['params'] + + new_pipeline.slat_normalization = args['slat_normalization'] + + new_pipeline._init_image_cond_model(args['image_cond_model']) + + return new_pipeline + + def _init_image_cond_model(self, name: str): + """ + Initialize the image conditioning model. + """ + dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True) + dinov2_model.eval() + self.models['image_cond_model'] = dinov2_model + transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + self.image_cond_model_transform = transform + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + if has_alpha: + output = input + else: + input = input.convert('RGB') + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if getattr(self, 'rembg_session', None) is None: + self.rembg_session = rembg.new_session('u2net') + output = rembg.remove(input, session=self.rembg_session) + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1.2) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = output.resize((518, 518), Image.Resampling.LANCZOS) + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor: + """ + Encode the image. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image to encode + + Returns: + torch.Tensor: The encoded features. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).to(self.device) + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.image_cond_model_transform(image).to(self.device) + features = self.models['image_cond_model'](image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + cond = self.encode_image(image) + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample occupancy latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + # Decode occupancy latent + decoder = self.models['sparse_structure_decoder'] + coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() + + return coords + + def decode_slat( + self, + slat: sp.SparseTensor, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + ) -> dict: + """ + Decode the structured latent. + + Args: + slat (sp.SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + dict: The decoded structured latent. + """ + ret = {} + if 'mesh' in formats: + ret['mesh'] = self.models['slat_decoder_mesh'](slat) + if 'gaussian' in formats: + ret['gaussian'] = self.models['slat_decoder_gs'](slat) + if 'radiance_field' in formats: + ret['radiance_field'] = self.models['slat_decoder_rf'](slat) + return ret + + def sample_slat( + self, + cond: dict, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + flow_model = self.models['slat_flow_model'] + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.slat_sampler_params, **sampler_params} + slat = self.slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + @torch.no_grad() + def run( + self, + image: Image.Image, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + preprocess_image: bool = True, + ) -> dict: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + slat_sampler_params (dict): Additional parameters for the structured latent sampler. + preprocess_image (bool): Whether to preprocess the image. + """ + if preprocess_image: + image = self.preprocess_image(image) + cond = self.get_cond([image]) + torch.manual_seed(seed) + coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) + slat = self.sample_slat(cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) + + @contextmanager + def inject_sampler_multi_image( + self, + sampler_name: str, + num_images: int, + num_steps: int, + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ): + """ + Inject a sampler with multiple images as condition. + + Args: + sampler_name (str): The name of the sampler to inject. + num_images (int): The number of images to condition on. + num_steps (int): The number of steps to run the sampler for. + """ + sampler = getattr(self, sampler_name) + setattr(sampler, f'_old_inference_model', sampler._inference_model) + + if mode == 'stochastic': + if num_images > num_steps: + print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. " + "This may lead to performance degradation.\033[0m") + + cond_indices = (np.arange(num_steps) % num_images).tolist() + def _new_inference_model(self, model, x_t, t, cond, **kwargs): + cond_idx = cond_indices.pop(0) + cond_i = cond[cond_idx:cond_idx+1] + return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs) + + elif mode =='multidiffusion': + from .samplers import FlowEulerSampler + def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + if cfg_interval[0] <= t <= cfg_interval[1]: + preds = [] + for i in range(len(cond)): + preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + pred = sum(preds) / len(preds) + neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred + else: + preds = [] + for i in range(len(cond)): + preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + pred = sum(preds) / len(preds) + return pred + + else: + raise ValueError(f"Unsupported mode: {mode}") + + sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler)) + + yield + + sampler._inference_model = sampler._old_inference_model + delattr(sampler, f'_old_inference_model') + + @torch.no_grad() + def run_multi_image( + self, + images: List[Image.Image], + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + preprocess_image: bool = True, + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ) -> dict: + """ + Run the pipeline with multiple images as condition + + Args: + images (List[Image.Image]): The multi-view images of the assets + num_samples (int): The number of samples to generate. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + slat_sampler_params (dict): Additional parameters for the structured latent sampler. + preprocess_image (bool): Whether to preprocess the image. + """ + if preprocess_image: + images = [self.preprocess_image(image) for image in images] + cond = self.get_cond(images) + cond['neg_cond'] = cond['neg_cond'][:1] + torch.manual_seed(seed) + ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps') + with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode): + coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) + slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') + with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode): + slat = self.sample_slat(cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec397d1e795baedf05d52ef49f5161885151cafc --- /dev/null +++ b/trellis/renderers/__init__.py @@ -0,0 +1,31 @@ +import importlib + +__attributes = { + 'OctreeRenderer': 'octree_renderer', + 'GaussianRenderer': 'gaussian_render', + 'MeshRenderer': 'mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .octree_renderer import OctreeRenderer + from .gaussian_render import GaussianRenderer + from .mesh_renderer import MeshRenderer \ No newline at end of file diff --git a/trellis/renderers/gaussian_render.py b/trellis/renderers/gaussian_render.py new file mode 100644 index 0000000000000000000000000000000000000000..272cf07ceaf2cb14bc7b9b82772721d97fd954c8 --- /dev/null +++ b/trellis/renderers/gaussian_render.py @@ -0,0 +1,231 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from easydict import EasyDict as edict +import numpy as np +from ..representations.gaussian import Gaussian +from .sh_utils import eval_sh +import torch.nn.functional as F +from easydict import EasyDict as edict + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'GaussianRasterizer' not in globals(): + from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + kernel_size = pipe.kernel_size + subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda") + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + kernel_size=kernel_size, + subpixel_offset=subpixel_offset, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) + dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = pc.get_features + else: + colors_precomp = override_color + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp + ) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return edict({"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "radii": radii}) + + +class GaussianRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.pipe = edict({ + "kernel_size": 0.1, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "scale_modifier": 1.0, + "debug": False + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + gausssian: Gaussian, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None + ) -> edict: + """ + Render the gausssian. + + Args: + gaussian : gaussianmodule + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier) + + if ssaa > 1: + render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret['render'] + }) + return ret diff --git a/trellis/renderers/mesh_renderer.py b/trellis/renderers/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..b0726def9a515834db145be3a24817fa97b09071 --- /dev/null +++ b/trellis/renderers/mesh_renderer.py @@ -0,0 +1,140 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +import nvdiffrast.torch as dr +from easydict import EasyDict as edict +from ..representations.mesh import MeshExtractResult +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. + glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. + """ + def __init__(self, rendering_options={}, device='cuda'): + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1 + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : MeshExtractResult, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["mask", "normal", "depth"] + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color" + + Returns: + edict based on return_types containing: + color (torch.Tensor): [3, H, W] rendered color image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image + normal_map (torch.Tensor): [3, H, W] rendered normal map image + mask (torch.Tensor): [H, W] rendered mask image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device) + ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types} + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + RT = extrinsics.unsqueeze(0) + full_proj = (perspective @ extrinsics).unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces_int = mesh.faces.int() + rast, _ = dr.rasterize( + self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa)) + + out_dict = edict() + for type in return_types: + img = None + if type == "mask" : + img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + elif type == "normal" : + img = dr.interpolate( + mesh.face_normal.reshape(1, -1, 3), rast, + torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3) + )[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + # normalize norm pictures + img = (img + 1) / 2 + elif type == "normal_map" : + img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + elif type == "color" : + img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + + if ssaa > 1: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = img.permute(0, 3, 1, 2).squeeze() + out_dict[type] = img + + return out_dict diff --git a/trellis/renderers/octree_renderer.py b/trellis/renderers/octree_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..c72541888c60591109c8690bd269669faad667c0 --- /dev/null +++ b/trellis/renderers/octree_renderer.py @@ -0,0 +1,300 @@ +import numpy as np +import torch +import torch.nn.functional as F +import math +import cv2 +from scipy.stats import qmc +from easydict import EasyDict as edict +from ..representations.octree import DfsOctree + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'OctreeTrivecRasterizer' not in globals(): + from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = edict( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=octree.active_sh_degree, + campos=viewpoint_camera.camera_center, + with_distloss=pipe.with_distloss, + jitter=pipe.jitter, + debug=pipe.debug, + ) + + positions = octree.get_xyz + if octree.primitive == "voxel": + densities = octree.get_density + elif octree.primitive == "gaussian": + opacities = octree.get_opacity + elif octree.primitive == "trivec": + trivecs = octree.get_trivec + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + elif octree.primitive == "decoupoly": + decoupolys_V, decoupolys_g = octree.get_decoupoly + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + else: + raise ValueError(f"Unknown primitive {octree.primitive}") + depths = octree.get_depth + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + colors_precomp = None + shs = octree.get_features + if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None: + colors_precomp = colors_overwrite + shs = None + + ret = edict() + + if octree.primitive == "voxel": + renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, distloss = renderer( + positions = positions, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + ret['distloss'] = distloss + elif octree.primitive == "gaussian": + renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + opacities = opacities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "trivec": + raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1] + renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, percent_depth = renderer( + positions = positions, + trivecs = trivecs, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + colors_overwrite = colors_overwrite, + depths = depths, + aabb = octree.aabb, + aux = aux, + halton_sampler = halton_sampler, + ) + ret['percent_depth'] = percent_depth + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "decoupoly": + raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1] + renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + decoupolys_V = decoupolys_V, + decoupolys_g = decoupolys_g, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + + return ret + + +class OctreeRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + try: + import diffoctreerast + except ImportError: + print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m") + self.unsupported = True + else: + self.unsupported = False + + self.pipe = edict({ + "with_distloss": False, + "with_aux": False, + "scale_modifier": 1.0, + "used_rank": None, + "jitter": False, + "debug": False, + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.halton_sampler = qmc.Halton(2, scramble=False) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + octree: DfsOctree, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: + """ + Render the octree. + + Args: + octree (Octree): octree + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + distloss (Optional[torch.Tensor]): (H, W) rendered distance loss + percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth + aux (Optional[edict]): auxiliary tensors + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.unsupported: + image = np.zeros((512, 512, 3), dtype=np.uint8) + text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0] + origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 + image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA) + return { + 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255, + } + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + if self.pipe["with_aux"]: + aux = { + 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + } + for k in aux.keys(): + aux[k].requires_grad_() + aux[k].retain_grad() + else: + aux = None + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler) + + if ssaa > 1: + render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + if hasattr(render_ret, 'percent_depth'): + render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret.rgb, + 'depth': render_ret.depth, + 'alpha': render_ret.alpha, + }) + if self.pipe["with_distloss"] and 'distloss' in render_ret: + ret['distloss'] = render_ret.distloss + if self.pipe["with_aux"]: + ret['aux'] = aux + if hasattr(render_ret, 'percent_depth'): + ret['percent_depth'] = render_ret.percent_depth + return ret diff --git a/trellis/renderers/sh_utils.py b/trellis/renderers/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a54612b24cde4e1ab6da3fb37142cc5d0248ada8 --- /dev/null +++ b/trellis/renderers/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/trellis/representations/__init__.py b/trellis/representations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26c62155fc77f668549092c55fca0c610aa02540 --- /dev/null +++ b/trellis/representations/__init__.py @@ -0,0 +1,4 @@ +from .radiance_field import Strivec +from .octree import DfsOctree as Octree +from .gaussian import Gaussian +from .mesh import MeshExtractResult diff --git a/trellis/representations/gaussian/__init__.py b/trellis/representations/gaussian/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3de6e180bd732836af876d748255595be2d4d74 --- /dev/null +++ b/trellis/representations/gaussian/__init__.py @@ -0,0 +1 @@ +from .gaussian_model import Gaussian \ No newline at end of file diff --git a/trellis/representations/gaussian/gaussian_model.py b/trellis/representations/gaussian/gaussian_model.py new file mode 100644 index 0000000000000000000000000000000000000000..373411cb16aa2baf466eaa600fb4d4248bd550c0 --- /dev/null +++ b/trellis/representations/gaussian/gaussian_model.py @@ -0,0 +1,209 @@ +import torch +import numpy as np +from plyfile import PlyData, PlyElement +from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation +import utils3d + + +class Gaussian: + def __init__( + self, + aabb : list, + sh_degree : int = 0, + mininum_kernel_size : float = 0.0, + scaling_bias : float = 0.01, + opacity_bias : float = 0.1, + scaling_activation : str = "exp", + device='cuda' + ): + self.init_params = { + 'aabb': aabb, + 'sh_degree': sh_degree, + 'mininum_kernel_size': mininum_kernel_size, + 'scaling_bias': scaling_bias, + 'opacity_bias': opacity_bias, + 'scaling_activation': scaling_activation, + } + + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.mininum_kernel_size = mininum_kernel_size + self.scaling_bias = scaling_bias + self.opacity_bias = opacity_bias + self.scaling_activation_type = scaling_activation + self.device = device + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.setup_functions() + + self._xyz = None + self._features_dc = None + self._features_rest = None + self._scaling = None + self._rotation = None + self._opacity = None + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = torch.nn.functional.softplus + self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x)) + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda() + self.rots_bias = torch.zeros((4)).cuda() + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda() + + @property + def get_scaling(self): + scales = self.scaling_activation(self._scaling + self.scale_bias) + scales = torch.square(scales) + self.mininum_kernel_size ** 2 + scales = torch.sqrt(scales) + return scales + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation + self.rots_bias[None, :]) + + @property + def get_xyz(self): + return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] + + @property + def get_features(self): + return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity + self.opacity_bias) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]) + + def from_scaling(self, scales): + scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2) + self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias + + def from_rotation(self, rots): + self._rotation = rots - self.rots_bias[None, :] + + def from_xyz(self, xyz): + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + + def from_features(self, features): + self._features_dc = features + + def from_opacity(self, opacities): + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): + xyz = self.get_xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy() + scale = torch.log(self.get_scaling).detach().cpu().numpy() + rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy() + + if transform is not None: + transform = np.array(transform) + xyz = np.matmul(xyz, transform.T) + rotation = utils3d.numpy.quaternion_to_matrix(rotation) + rotation = np.matmul(transform, rotation) + rotation = utils3d.numpy.matrix_to_quaternion(rotation) + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + if self.sh_degree > 0: + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + if transform is not None: + transform = np.array(transform) + xyz = np.matmul(xyz, transform) + rotation = utils3d.numpy.quaternion_to_matrix(rotation) + rotation = np.matmul(rotation, transform) + rotation = utils3d.numpy.matrix_to_quaternion(rotation) + + # convert to actual gaussian attributes + xyz = torch.tensor(xyz, dtype=torch.float, device=self.device) + features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + if self.sh_degree > 0: + features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device)) + scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device)) + rots = torch.tensor(rots, dtype=torch.float, device=self.device) + + # convert to _hidden attributes + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + self._features_dc = features_dc + if self.sh_degree > 0: + self._features_rest = features_extra + else: + self._features_rest = None + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias + self._rotation = rots - self.rots_bias[None, :] + \ No newline at end of file diff --git a/trellis/representations/gaussian/general_utils.py b/trellis/representations/gaussian/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae982066ab2d04fac15e997df9dbd37620ad08ed --- /dev/null +++ b/trellis/representations/gaussian/general_utils.py @@ -0,0 +1,133 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/trellis/representations/mesh/__init__.py b/trellis/representations/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fffa2c12907ed38dce29455df884c474704bd663 --- /dev/null +++ b/trellis/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult diff --git a/trellis/representations/mesh/cube2mesh.py b/trellis/representations/mesh/cube2mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..8a6d3b57c49b7140198b5d438a3e618e0ba23d64 --- /dev/null +++ b/trellis/representations/mesh/cube2mesh.py @@ -0,0 +1,150 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from ...modules.sparse import SparseTensor +from easydict import EasyDict as edict +from .utils_cube import * +from .flexicube import FlexiCubes + + +class MeshExtractResult: + def __init__(self, + vertices, + faces, + vertex_attrs=None, + res=64 + ): + self.vertices = vertices + self.faces = faces.long() + self.vertex_attrs = vertex_attrs + self.face_normal = self.comput_face_normals(vertices, faces) + self.res = res + self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) + + # training only + self.tsdf_v = None + self.tsdf_s = None + self.reg_loss = None + + def comput_face_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + face_normals = torch.nn.functional.normalize(face_normals, dim=1) + # print(face_normals.min(), face_normals.max(), face_normals.shape) + return face_normals[:, None, :].repeat(1, 3, 1) + + def comput_v_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + v_normals = torch.zeros_like(verts) + v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) + + v_normals = torch.nn.functional.normalize(v_normals, dim=1) + return v_normals + + +class SparseFeatures2Mesh: + def __init__(self, device="cuda", res=64, use_color=True): + ''' + a model to generate a mesh from sparse features structures using flexicube + ''' + super().__init__() + self.device=device + self.res = res + self.mesh_extractor = FlexiCubes(device=device) + self.sdf_bias = -1.0 / res + verts, cube = construct_dense_grid(self.res, self.device) + self.reg_c = cube.to(self.device) + self.reg_v = verts.to(self.device) + self.use_color = use_color + self._calc_layout() + + def _calc_layout(self): + LAYOUTS = { + 'sdf': {'shape': (8, 1), 'size': 8}, + 'deform': {'shape': (8, 3), 'size': 8 * 3}, + 'weights': {'shape': (21,), 'size': 21} + } + if self.use_color: + ''' + 6 channel color including normal map + ''' + LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} + self.layouts = edict(LAYOUTS) + start = 0 + for k, v in self.layouts.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.feats_channels = start + + def get_layout(self, feats : torch.Tensor, name : str): + if name not in self.layouts: + return None + return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) + + def __call__(self, cubefeats : SparseTensor, training=False): + """ + Generates a mesh based on the specified sparse voxel structures. + Args: + cube_attrs [Nx21] : Sparse Tensor attrs about cube weights + verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal + Returns: + return the success tag and ni you loss, + """ + # add sdf bias to verts_attrs + coords = cubefeats.coords[:, 1:] + feats = cubefeats.feats + + sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] + sdf += self.sdf_bias + v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] + v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) + v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) + weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) + if self.use_color: + sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] + else: + sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] + colors_d = None + + x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) + + vertices, faces, L_dev, colors = self.mesh_extractor( + voxelgrid_vertices=x_nx3, + scalar_field=sdf_d, + cube_idx=self.reg_c, + resolution=self.res, + beta=weights_d[:, :12], + alpha=weights_d[:, 12:20], + gamma_f=weights_d[:, 20], + voxelgrid_colors=colors_d, + training=training) + + mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) + if training: + if mesh.success: + reg_loss += L_dev.mean() * 0.5 + reg_loss += (weights[:,:20]).abs().mean() * 0.2 + mesh.reg_loss = reg_loss + mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) + mesh.tsdf_s = v_attrs[:, 0] + return mesh diff --git a/trellis/representations/mesh/flexicube.py b/trellis/representations/mesh/flexicube.py new file mode 100644 index 0000000000000000000000000000000000000000..0071b1ba000a826e870e43971a68a4963b23a6c3 --- /dev/null +++ b/trellis/representations/mesh/flexicube.py @@ -0,0 +1,359 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from .tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + def __init__(self, device="cuda"): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + + def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, + weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False): + surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx) + if surf_cubes.sum() == 0: + return ( + torch.zeros((0, 3), device=self.device), + torch.zeros((0, 3), dtype=torch.long, device=self.device), + torch.zeros((0), device=self.device), + torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None + ) + beta, alpha, gamma_f = self._normalize_weights( + beta, alpha, gamma_f, surf_cubes, weight_scale) + + if voxelgrid_colors is not None: + voxelgrid_colors = torch.sigmoid(voxelgrid_colors) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( + scalar_field, cube_idx, surf_cubes + ) + + vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd( + voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors) + vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate( + scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, + vd_idx_map, surf_edges_mask, training, vd_color) + return vertices, faces, L_dev, vertices_color + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta is not None: + beta = (torch.tanh(beta) * weight_scale + 1) + else: + beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha is not None: + alpha = (torch.tanh(alpha) * weight_scale + 1) + else: + alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = scalar_field < 0 + all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, scalar_field, cube_idx): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = scalar_field < 0 + occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)] + , edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + if voxelgrid_colors is not None: + C = voxelgrid_colors.shape[-1] + surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + # if color is not None: + # vd_color = [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + # if color is not None: + # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3)) + + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + # if color is not None: + # vd_color = torch.cat(vd_color) + # else: + # vd_color = None + + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + + ''' + interpolate colors use the same method as dual vertices + ''' + if voxelgrid_colors is not None: + vd_color = torch.zeros((total_num_vd, C), device=self.device) + c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C) + uc_group = self._linear_interp(s_group * alpha_group, c_group) + vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum + else: + vd_color = None + + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map, vd_color + + def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2] + gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3] + if not training: + mask = (gamma_02 > gamma_13) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2 + vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + + if vd_color is not None: + color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1]) + color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2 + color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2 + color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + vd_color = torch.cat([vd_color, color_center]) + + + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices, vd_color diff --git a/trellis/representations/mesh/tables.py b/trellis/representations/mesh/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..203d6b1c40410e57d56f5d6f86973d22eea61057 --- /dev/null +++ b/trellis/representations/mesh/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] \ No newline at end of file diff --git a/trellis/representations/mesh/utils_cube.py b/trellis/representations/mesh/utils_cube.py new file mode 100644 index 0000000000000000000000000000000000000000..9befc1de4561d2d682873e0b752948275ddc9189 --- /dev/null +++ b/trellis/representations/mesh/utils_cube.py @@ -0,0 +1,61 @@ +import torch +cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) +cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) +cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) + +def construct_dense_grid(res, device='cuda'): + '''construct a dense grid based on resolution''' + res_v = res + 1 + vertsid = torch.arange(res_v ** 3, device=device) + coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() + cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] + cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) + verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) + return verts, cube_fx8 + + +def construct_voxel_grid(coords): + verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) + verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) + cubes = inverse_indices.reshape(-1, 8) + return verts_unique, cubes + + +def cubes_to_verts(num_verts, cubes, value, reduce='mean'): + """ + Args: + cubes [Vx8] verts index for each cube + value [Vx8xM] value to be scattered + Operation: + reduced[cubes[i][j]][k] += value[i][k] + """ + M = value.shape[2] # number of channels + reduced = torch.zeros(num_verts, M, device=cubes.device) + return torch.scatter_reduce(reduced, 0, + cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), + value.flatten(0, 1), reduce=reduce, include_self=False) + +def sparse_cube2verts(coords, feats, training=True): + new_coords, cubes = construct_voxel_grid(coords) + new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) + if training: + con_loss = torch.mean((feats - new_feats[cubes]) ** 2) + else: + con_loss = 0.0 + return new_coords, new_feats, con_loss + + +def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): + F = feats.shape[-1] + dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) + if sdf_init: + dense_attrs[..., 0] = 1 # initial outside sdf value + dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats + return dense_attrs.reshape(-1, F) + + +def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): + return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) + \ No newline at end of file diff --git a/trellis/representations/octree/__init__.py b/trellis/representations/octree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f66a39a5a7498e2e99fe9d94d663796b3bc157b5 --- /dev/null +++ b/trellis/representations/octree/__init__.py @@ -0,0 +1 @@ +from .octree_dfs import DfsOctree \ No newline at end of file diff --git a/trellis/representations/octree/octree_dfs.py b/trellis/representations/octree/octree_dfs.py new file mode 100644 index 0000000000000000000000000000000000000000..710f18b73c6a68acbc7bfb470efef7632fbbd6ed --- /dev/null +++ b/trellis/representations/octree/octree_dfs.py @@ -0,0 +1,362 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +DEFAULT_TRIVEC_CONFIG = { + 'dim': 8, + 'rank': 8, +} + +DEFAULT_VOXEL_CONFIG = { + 'solid': False, +} + +DEFAULT_DECOPOLY_CONFIG = { + 'degree': 8, + 'rank': 16, +} + + +class DfsOctree: + """ + Sparse Voxel Octree (SVO) implementation for PyTorch. + Using Depth-First Search (DFS) order to store the octree. + DFS order suits rendering and ray tracing. + + The structure and data are separatedly stored. + Structure is stored as a continuous array, each element is a 3*32 bits descriptor. + |-----------------------------------------| + | 0:3 bits | 4:31 bits | + | leaf num | unused | + |-----------------------------------------| + | 0:31 bits | + | child ptr | + |-----------------------------------------| + | 0:31 bits | + | data ptr | + |-----------------------------------------| + Each element represents a non-leaf node in the octree. + The valid mask is used to indicate whether the children are valid. + The leaf mask is used to indicate whether the children are leaf nodes. + The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr. + The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr. + + There are also auxiliary arrays to store the additional structural information to facilitate parallel processing. + - Position: the position of the octree nodes. + - Depth: the depth of the octree nodes. + + Args: + depth (int): the depth of the octree. + """ + + def __init__( + self, + depth, + aabb=[0,0,0,1,1,1], + sh_degree=2, + primitive='voxel', + primitive_config={}, + device='cuda', + ): + self.max_depth = depth + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.device = device + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.primitive = primitive + self.primitive_config = primitive_config + + self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device) + self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) + self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) + self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device) + self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device) + self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device) + self.depth[:, 0] = 1 + + self.data = ['position', 'depth'] + self.param_names = [] + + if primitive == 'voxel': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac'] + self.param_names += ['features_dc', 'features_ac'] + if not primitive_config.get('solid', False): + self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data.append('density') + self.param_names.append('density') + elif primitive == 'gaussian': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac', 'opacity'] + self.param_names += ['features_dc', 'features_ac', 'opacity'] + elif primitive == 'trivec': + self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['trivec', 'density', 'features_dc', 'features_ac'] + self.param_names += ['trivec', 'density', 'features_dc', 'features_ac'] + elif primitive == 'decoupoly': + self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device) + self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + + self.setup_functions() + + def setup_functions(self): + self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x) + self.opacity_activation = lambda x: torch.sigmoid(x - 6) + self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 + self.color_activation = lambda x: torch.sigmoid(x) + + @property + def num_non_leaf_nodes(self): + return self.structure.shape[0] + + @property + def num_leaf_nodes(self): + return self.depth.shape[0] + + @property + def cur_depth(self): + return self.depth.max().item() + + @property + def occupancy(self): + return self.num_leaf_nodes / 8 ** self.cur_depth + + @property + def get_xyz(self): + return self.position + + @property + def get_depth(self): + return self.depth + + @property + def get_density(self): + if self.primitive == 'voxel' and self.voxel_config['solid']: + return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device) + return self.density_activation(self.density) + + @property + def get_opacity(self): + return self.opacity_activation(self.density) + + @property + def get_trivec(self): + return self.trivec + + @property + def get_decoupoly(self): + return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g + + @property + def get_color(self): + return self.color_activation(self.colors) + + @property + def get_features(self): + if self.sh_degree == 0: + return self.features_dc + return torch.cat([self.features_dc, self.features_ac], dim=-2) + + def state_dict(self): + ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'trivec_config': self.trivec_config, 'voxel_config': self.voxel_config, 'primitive': self.primitive} + if hasattr(self, 'density_shift'): + ret['density_shift'] = self.density_shift + for data in set(self.data + self.param_names): + if not isinstance(getattr(self, data), nn.Module): + ret[data] = getattr(self, data) + else: + ret[data] = getattr(self, data).state_dict() + return ret + + def load_state_dict(self, state_dict): + keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth'])) + for key in keys: + if key not in state_dict: + print(f"Warning: key {key} not found in the state_dict.") + continue + try: + if not isinstance(getattr(self, key), nn.Module): + setattr(self, key, state_dict[key]) + else: + getattr(self, key).load_state_dict(state_dict[key]) + except Exception as e: + print(e) + raise ValueError(f"Error loading key {key}.") + + def gather_from_leaf_children(self, data): + """ + Gather the data from the leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + leaf_cnt = self.structure[:, 0] + leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device) + for i in range(8): + if leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[leaf_cnt_masks[i], 2] + for j in range(i+1): + ret[leaf_cnt_masks[i]] += data[start + j] + return ret + + def gather_from_non_leaf_children(self, data): + """ + Gather the data from the non-leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + non_leaf_cnt = 8 - self.structure[:, 0] + non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros_like(data, device=self.device) + for i in range(8): + if non_leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[non_leaf_cnt_masks[i], 1] + for j in range(i+1): + ret[non_leaf_cnt_masks[i]] += data[start + j] + return ret + + def structure_control(self, mask): + """ + Control the structure of the octree. + + Args: + mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. + """ + # Dont subdivide when the depth is the maximum. + mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0) + # Dont merge when the depth is the minimum. + mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0) + + # Gather control mask + structre_ctrl = self.gather_from_leaf_children(mask) + structre_ctrl[structre_ctrl==-8] = -1 + + new_leaf_num = self.structure[:, 0].clone() + # Modify the leaf num. + structre_valid = structre_ctrl >= 0 + new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes. + structre_delete = structre_ctrl < 0 + merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) + new_leaf_num += merged_nodes # Delete the merged nodes. + + # Update the structure array to allocate new nodes. + mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes. + mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. + new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_structure_length = new_structre_idx[-1].item() + new_structre_idx = new_structre_idx[:-1] + new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device) + new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid] + + # Initialize the new nodes. + new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device) + new_node_mask[new_structre_idx[structre_valid]] = False + new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. + new_node_num = new_node_mask.sum().item() + + # Rebuild child ptr. + non_leaf_cnt = 8 - new_structure[:, 0] + new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 1] = new_child_ptr + 1 + + # Rebuild data ptr with old data. + leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device) + leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) + old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + + # Update the data array + subdivide_mask = mask == 1 + merge_mask = mask == -1 + data_valid = ~(subdivide_mask | merge_mask) + mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes + mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes + mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes + mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes + new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_data_length = new_data_idx[-1].item() + new_data_idx = new_data_idx[:-1] + new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data} + for data in self.data: + new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] + + # Rebuild data ptr + leaf_cnt = new_structure[:, 0] + new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 2] = new_data_ptr + + # Initialize the new data array + ## For subdivide nodes + if subdivide_mask.sum() > 0: + subdivide_data_ptr = new_structure[new_node_mask, 2] + for data in self.data: + for i in range(8): + if data == 'position': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5 + scale = 2 ** (-1.0 - self.depth[subdivide_mask]) + new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale + elif data == 'depth': + new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask]))) + elif data == 'trivec': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5 + coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1) + axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1) + coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1 + new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True) + else: + new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask] + ## For merge nodes + if merge_mask.sum() > 0: + merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device) + merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]]) + for i in range(8): + merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i + old_merge_data_ptr = self.structure[structre_delete, 2] + for data in self.data: + if data == 'position': + scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) + new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5 + elif data == 'depth': + new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2) + elif data == 'trivec': + new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr] + else: + new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr] + + # Update the structure and data array + self.structure = new_structure + for data in self.data: + setattr(self, data, new_data[data]) + + # Save data array control temp variables + self.data_rearrange_buffer = { + 'subdivide_mask': subdivide_mask, + 'merge_mask': merge_mask, + 'data_valid': data_valid, + 'new_data_idx': new_data_idx, + 'new_data_length': new_data_length, + 'new_data': new_data + } diff --git a/trellis/representations/radiance_field/__init__.py b/trellis/representations/radiance_field/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b72a1b7e76b509ee5a5e6979858eb17b4158a151 --- /dev/null +++ b/trellis/representations/radiance_field/__init__.py @@ -0,0 +1 @@ +from .strivec import Strivec \ No newline at end of file diff --git a/trellis/representations/radiance_field/strivec.py b/trellis/representations/radiance_field/strivec.py new file mode 100644 index 0000000000000000000000000000000000000000..f2dc78cdd5a08e994daa57247a3f01f3d43986f9 --- /dev/null +++ b/trellis/representations/radiance_field/strivec.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..octree import DfsOctree as Octree + + +class Strivec(Octree): + def __init__( + self, + resolution: int, + aabb: list, + sh_degree: int = 0, + rank: int = 8, + dim: int = 8, + device: str = "cuda", + ): + assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2" + self.resolution = resolution + depth = int(np.round(np.log2(resolution))) + super().__init__( + depth=depth, + aabb=aabb, + sh_degree=sh_degree, + primitive="trivec", + primitive_config={"rank": rank, "dim": dim}, + device=device, + ) diff --git a/trellis/utils/__init__.py b/trellis/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trellis/utils/general_utils.py b/trellis/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b91e6d075dbb2c02438b5f345ec3deb164fa7a8f --- /dev/null +++ b/trellis/utils/general_utils.py @@ -0,0 +1,187 @@ +import numpy as np +import cv2 +import torch + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. + """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + diff --git a/trellis/utils/postprocessing_utils.py b/trellis/utils/postprocessing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c35fe2f7138a18db2767c37646ba77358eba1b --- /dev/null +++ b/trellis/utils/postprocessing_utils.py @@ -0,0 +1,587 @@ +from typing import * +import numpy as np +import torch +import utils3d +import nvdiffrast.torch as dr +from tqdm import tqdm +import trimesh +import trimesh.visual +import xatlas +import pyvista as pv +from pymeshfix import _meshfix +import igraph +import cv2 +from PIL import Image +from .random_utils import sphere_hammersley_sequence +from .render_utils import render_multiview +from ..renderers import GaussianRenderer +from ..representations import Strivec, Gaussian, MeshExtractResult + + +@torch.no_grad() +def _fill_holes( + verts, + faces, + max_hole_size=0.04, + max_hole_nbe=32, + resolution=128, + num_views=500, + debug=False, + verbose=False +): + """ + Rasterize a mesh from multiple views and remove invisible faces. + Also includes postprocessing to: + 1. Remove connected components that are have low visibility. + 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole. + + Args: + verts (torch.Tensor): Vertices of the mesh. Shape (V, 3). + faces (torch.Tensor): Faces of the mesh. Shape (F, 3). + max_hole_size (float): Maximum area of a hole to fill. + resolution (int): Resolution of the rasterization. + num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + # Construct cameras + yaws = [] + pitchs = [] + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views) + yaws.append(y) + pitchs.append(p) + yaws = torch.tensor(yaws).cuda() + pitchs = torch.tensor(pitchs).cuda() + radius = 2.0 + fov = torch.deg2rad(torch.tensor(40)).cuda() + projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) + views = [] + for (yaw, pitch) in zip(yaws, pitchs): + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda().float() * radius + view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + views.append(view) + views = torch.stack(views, dim=0) + + # Rasterize + visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) + rastctx = utils3d.torch.RastContext(backend='cuda') + for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'): + view = views[i] + buffers = utils3d.torch.rasterize_triangle_faces( + rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection + ) + face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1 + face_id = torch.unique(face_id).long() + visblity[face_id] += 1 + visblity = visblity.float() / num_views + + # Mincut + ## construct outer faces + edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) + boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) + connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge) + outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device) + for i in range(len(connected_components)): + outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) + outer_face_indices = outer_face_indices.nonzero().reshape(-1) + + ## construct inner faces + inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) + if verbose: + tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces') + if inner_face_indices.shape[0] == 0: + return verts, faces + + ## Construct dual graph (faces as nodes, edges as edges) + dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) + dual_edge2edge = edges[dual_edge2edge] + dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1) + if verbose: + tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges') + + ## solve mincut problem + ### construct main graph + g = igraph.Graph() + g.add_vertices(faces.shape[0]) + g.add_edges(dual_edges.cpu().numpy()) + g.es['weight'] = dual_edges_weights.cpu().numpy() + + ### source and target + g.add_vertex('s') + g.add_vertex('t') + + ### connect invisible faces to source + g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### connect outer faces to target + g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### solve mincut + cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist()) + remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device) + if verbose: + tqdm.write(f'Mincut solved, start checking the cut') + + ### check if the cut is valid with each connected component + to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices]) + if debug: + tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}') + valid_remove_cc = [] + cutting_edges = [] + for cc in to_remove_cc: + #### check if the connected component has low visibility + visblity_median = visblity[remove_face_indices[cc]].median() + if debug: + tqdm.write(f'visblity_median: {visblity_median}') + if visblity_median > 0.25: + continue + + #### check if the cuting loop is small enough + cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True) + cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)] + if len(cc_new_boundary_edge_indices) > 0: + cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices]) + cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc] + cc_new_boundary_edges_cc_area = [] + for i, edge_cc in enumerate(cc_new_boundary_edge_cc): + _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i] + _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i] + cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5) + if debug: + cutting_edges.append(cc_new_boundary_edge_indices) + tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}') + if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): + continue + + valid_remove_cc.append(cc) + + if debug: + face_v = verts[faces].mean(dim=1).cpu().numpy() + vis_dual_edges = dual_edges.cpu().numpy() + vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8) + vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255] + vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] + vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] + if len(valid_remove_cc) > 0: + vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0] + utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors) + + vis_verts = verts.cpu().numpy() + vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() + utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges) + + + if len(valid_remove_cc) > 0: + remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)] + mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) + mask[remove_face_indices] = 0 + faces = faces[mask] + faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) + if verbose: + tqdm.write(f'Removed {(~mask).sum()} faces by mincut') + else: + if verbose: + tqdm.write(f'Removed 0 faces by mincut') + + mesh = _meshfix.PyTMesh() + mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) + mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) + verts, faces = mesh.return_arrays() + verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32) + + return verts, faces + + +def postprocess_mesh( + vertices: np.array, + faces: np.array, + simplify: bool = True, + simplify_ratio: float = 0.9, + fill_holes: bool = True, + fill_holes_max_hole_size: float = 0.04, + fill_holes_max_hole_nbe: int = 32, + fill_holes_resolution: int = 1024, + fill_holes_num_views: int = 1000, + debug: bool = False, + verbose: bool = False, +): + """ + Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + simplify (bool): Whether to simplify the mesh, using quadric edge collapse. + simplify_ratio (float): Ratio of faces to keep after simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_hole_size (float): Maximum area of a hole to fill. + fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill. + fill_holes_resolution (int): Resolution of the rasterization. + fill_holes_num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + + if verbose: + tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Simplify + if simplify and simplify_ratio > 0: + mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) + mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) + vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] + if verbose: + tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Remove invisible faces + if fill_holes: + vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda() + vertices, faces = _fill_holes( + vertices, faces, + max_hole_size=fill_holes_max_hole_size, + max_hole_nbe=fill_holes_max_hole_nbe, + resolution=fill_holes_resolution, + num_views=fill_holes_num_views, + debug=debug, + verbose=verbose, + ) + vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() + if verbose: + tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + return vertices, faces + + +def parametrize_mesh(vertices: np.array, faces: np.array): + """ + Parametrize a mesh to a texture space, using xatlas. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + """ + + vmapping, indices, uvs = xatlas.parametrize(vertices, faces) + + vertices = vertices[vmapping] + faces = indices + + return vertices, faces, uvs + + +def bake_texture( + vertices: np.array, + faces: np.array, + uvs: np.array, + observations: List[np.array], + masks: List[np.array], + extrinsics: List[np.array], + intrinsics: List[np.array], + texture_size: int = 2048, + near: float = 0.1, + far: float = 10.0, + mode: Literal['fast', 'opt'] = 'opt', + lambda_tv: float = 1e-2, + verbose: bool = False, +): + """ + Bake texture to a mesh from multiple observations. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + uvs (np.array): UV coordinates of the mesh. Shape (V, 2). + observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3). + masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W). + extrinsics (List[np.array]): List of extrinsics. Shape (4, 4). + intrinsics (List[np.array]): List of intrinsics. Shape (3, 3). + texture_size (int): Size of the texture. + near (float): Near plane of the camera. + far (float): Far plane of the camera. + mode (Literal['fast', 'opt']): Mode of texture baking. + lambda_tv (float): Weight of total variation loss in optimization. + verbose (bool): Whether to print progress. + """ + vertices = torch.tensor(vertices).cuda() + faces = torch.tensor(faces.astype(np.int32)).cuda() + uvs = torch.tensor(uvs).cuda() + observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations] + masks = [torch.tensor(m>0).bool().cuda() for m in masks] + views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics] + projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics] + + if mode == 'fast': + texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda() + texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda() + rastctx = utils3d.torch.RastContext(backend='cuda') + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + uv_map = rast['uv'][0].detach().flip(0) + mask = rast['mask'][0].detach().bool() & masks[0] + + # nearest neighbor interpolation + uv_map = (uv_map * texture_size).floor().long() + obs = observation[mask] + uv_map = uv_map[mask] + idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size + texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs) + texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device)) + + mask = texture_weights > 0 + texture[mask] /= texture_weights[mask][:, None] + texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8) + + # inpaint + mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + + elif mode == 'opt': + rastctx = utils3d.torch.RastContext(backend='cuda') + observations = [observations.flip(0) for observations in observations] + masks = [m.flip(0) for m in masks] + _uv = [] + _uv_dr = [] + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + _uv.append(rast['uv'].detach()) + _uv_dr.append(rast['uv_dr'].detach()) + + texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()) + optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + def tv_loss(texture): + return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \ + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) + + total_steps = 2500 + with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar: + for step in range(total_steps): + optimizer.zero_grad() + selected = np.random.randint(0, len(views)) + uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected] + render = dr.texture(texture, uv, uv_dr)[0] + loss = torch.nn.functional.l1_loss(render[mask], observation[mask]) + if lambda_tv > 0: + loss += lambda_tv * tv_loss(texture) + loss.backward() + optimizer.step() + # annealing + optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5) + pbar.set_postfix({'loss': loss.item()}) + pbar.update() + texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) + mask = 1 - utils3d.torch.rasterize_triangle_faces( + rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size + )['mask'][0].detach().cpu().numpy().astype(np.uint8) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + else: + raise ValueError(f'Unknown mode: {mode}') + + return texture + + +def to_glb( + app_rep: Union[Strivec, Gaussian], + mesh: MeshExtractResult, + simplify: float = 0.95, + fill_holes: bool = True, + fill_holes_max_size: float = 0.04, + texture_size: int = 1024, + debug: bool = False, + verbose: bool = True, +) -> trimesh.Trimesh: + """ + Convert a generated asset to a glb file. + + Args: + app_rep (Union[Strivec, Gaussian]): Appearance representation. + mesh (MeshExtractResult): Extracted mesh. + simplify (float): Ratio of faces to remove in simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_size (float): Maximum area of a hole to fill. + texture_size (int): Size of the texture. + debug (bool): Whether to print debug information. + verbose (bool): Whether to print progress. + """ + vertices = mesh.vertices.cpu().numpy() + faces = mesh.faces.cpu().numpy() + + # mesh postprocess + vertices, faces = postprocess_mesh( + vertices, faces, + simplify=simplify > 0, + simplify_ratio=simplify, + fill_holes=fill_holes, + fill_holes_max_hole_size=fill_holes_max_size, + fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)), + fill_holes_resolution=1024, + fill_holes_num_views=1000, + debug=debug, + verbose=verbose, + ) + + # parametrize mesh + vertices, faces, uvs = parametrize_mesh(vertices, faces) + + # bake texture + observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100) + masks = [np.any(observation > 0, axis=-1) for observation in observations] + extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] + intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] + texture = bake_texture( + vertices, faces, uvs, + observations, masks, extrinsics, intrinsics, + texture_size=texture_size, mode='opt', + lambda_tv=0.01, + verbose=verbose + ) + texture = Image.fromarray(texture) + + # rotate mesh (from z-up to y-up) + vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + material = trimesh.visual.material.PBRMaterial( + roughnessFactor=1.0, + baseColorTexture=texture, + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8) + ) + mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)) + return mesh + + +def simplify_gs( + gs: Gaussian, + simplify: float = 0.95, + verbose: bool = True, +): + """ + Simplify 3D Gaussians + NOTE: this function is not used in the current implementation for the unsatisfactory performance. + + Args: + gs (Gaussian): 3D Gaussian. + simplify (float): Ratio of Gaussians to remove in simplification. + """ + if simplify <= 0: + return gs + + # simplify + observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100) + observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations] + + # Following https://arxiv.org/pdf/2411.06019 + renderer = GaussianRenderer({ + "resolution": 1024, + "near": 0.8, + "far": 1.6, + "ssaa": 1, + "bg_color": (0,0,0), + }) + new_gs = Gaussian(**gs.init_params) + new_gs._features_dc = gs._features_dc.clone() + new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None + new_gs._opacity = torch.nn.Parameter(gs._opacity.clone()) + new_gs._rotation = torch.nn.Parameter(gs._rotation.clone()) + new_gs._scaling = torch.nn.Parameter(gs._scaling.clone()) + new_gs._xyz = torch.nn.Parameter(gs._xyz.clone()) + + start_lr = [1e-4, 1e-3, 5e-3, 0.025] + end_lr = [1e-6, 1e-5, 5e-5, 0.00025] + optimizer = torch.optim.Adam([ + {"params": new_gs._xyz, "lr": start_lr[0]}, + {"params": new_gs._rotation, "lr": start_lr[1]}, + {"params": new_gs._scaling, "lr": start_lr[2]}, + {"params": new_gs._opacity, "lr": start_lr[3]}, + ], lr=start_lr[0]) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + _zeta = new_gs.get_opacity.clone().detach().squeeze() + _lambda = torch.zeros_like(_zeta) + _delta = 1e-7 + _interval = 10 + num_target = int((1 - simplify) * _zeta.shape[0]) + + with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar: + for i in range(2500): + # prune + if i % 100 == 0: + mask = new_gs.get_opacity.squeeze() > 0.05 + mask = torch.nonzero(mask).squeeze() + new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask]) + new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask]) + new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask]) + new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask]) + new_gs._features_dc = new_gs._features_dc[mask] + new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None + _zeta = _zeta[mask] + _lambda = _lambda[mask] + # update optimizer state + for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]): + stored_state = optimizer.state[param_group['params'][0]] + if 'exp_avg' in stored_state: + stored_state['exp_avg'] = stored_state['exp_avg'][mask] + stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask] + del optimizer.state[param_group['params'][0]] + param_group['params'][0] = new_param + optimizer.state[param_group['params'][0]] = stored_state + + opacity = new_gs.get_opacity.squeeze() + + # sparisfy + if i % _interval == 0: + _zeta = _lambda + opacity.detach() + if opacity.shape[0] > num_target: + index = _zeta.topk(num_target)[1] + _m = torch.ones_like(_zeta, dtype=torch.bool) + _m[index] = 0 + _zeta[_m] = 0 + _lambda = _lambda + opacity.detach() - _zeta + + # sample a random view + view_idx = np.random.randint(len(observations)) + observation = observations[view_idx] + extrinsic = extrinsics[view_idx] + intrinsic = intrinsics[view_idx] + + color = renderer.render(new_gs, extrinsic, intrinsic)['color'] + rgb_loss = torch.nn.functional.l1_loss(color, observation) + loss = rgb_loss + \ + _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2)) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # update lr + for j in range(len(optimizer.param_groups)): + optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j]) + + pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()}) + pbar.update() + + new_gs._xyz = new_gs._xyz.data + new_gs._rotation = new_gs._rotation.data + new_gs._scaling = new_gs._scaling.data + new_gs._opacity = new_gs._opacity.data + + return new_gs diff --git a/trellis/utils/random_utils.py b/trellis/utils/random_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..420c4146cbb4c2973f2cc2e69f376a3836e65eeb --- /dev/null +++ b/trellis/utils/random_utils.py @@ -0,0 +1,30 @@ +import numpy as np + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] \ No newline at end of file diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54a33ce79c6d3e1e358ab1650ea14cfe1d50ba91 --- /dev/null +++ b/trellis/utils/render_utils.py @@ -0,0 +1,116 @@ +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer +from ..representations import Octree, Gaussian, MeshExtractResult +from ..modules import sparse as sp +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda() * r + extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs): + if isinstance(sample, Octree): + renderer = OctreeRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 0.8) + renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + renderer.pipe.primitive = sample.primitive + elif isinstance(sample, Gaussian): + renderer = GaussianRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 0.8) + renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 1) + renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1) + renderer.pipe.use_mip_gaussian = True + elif isinstance(sample, MeshExtractResult): + renderer = MeshRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 1) + renderer.rendering_options.far = options.get('far', 100) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + else: + raise ValueError(f'Unsupported sample type: {type(sample)}') + + rets = {} + for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose): + if not isinstance(sample, MeshExtractResult): + res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite) + if 'color' not in rets: rets['color'] = [] + if 'depth' not in rets: rets['depth'] = [] + rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + if 'percent_depth' in res: + rets['depth'].append(res['percent_depth'].detach().cpu().numpy()) + elif 'depth' in res: + rets['depth'].append(res['depth'].detach().cpu().numpy()) + else: + rets['depth'].append(None) + else: + res = renderer.render(sample, extr, intr) + if 'normal' not in rets: rets['normal'] = [] + rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + return rets + + +def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs): + yaws = torch.linspace(0, 2 * 3.1415, num_frames) + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def render_multiview(sample, resolution=512, nviews=30): + r = 2 + fov = 40 + cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) + return res['color'], extrinsics, intrinsics + + +def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs): + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = offset[0] + yaw = [y + yaw_offset for y in yaw] + pitch = [offset[1] for _ in range(4)] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) + return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) diff --git a/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl b/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..283474718e00d7677de9ef70a276d5868ec78578 Binary files /dev/null and b/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl differ diff --git a/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl b/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..79dfc23fcf95efcd37179cbf024ccae0f313baf4 --- /dev/null +++ b/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:026b3031cc647d279b5beb0a3ec2bfe992666d85f66431662d8f26be2b6894f9 +size 1047624