|
import torch |
|
from pytorch3d.renderer.mesh.shader import ShaderBase |
|
from pytorch3d.renderer import ( |
|
SoftPhongShader, |
|
) |
|
from pytorch3d.renderer import BlendParams |
|
|
|
|
|
class MultiOutputShader(ShaderBase):
    """Render several per-pixel buffers from one set of rasterizer fragments.

    ``forward`` returns a dict containing only the keys listed in
    ``choices`` (keys are inserted in the order rgb, mask, albedo, depth,
    normal, ccm):

    * ``"rgb"``: first three channels of the internal ``SoftPhongShader``
      output.
    * ``"mask"``: ``1.0`` where that output's alpha channel (channel 3) is
      positive, ``0.0`` elsewhere; trailing dimension of size 1.
    * ``"albedo"``: ``meshes.sample_textures(fragments)`` for the first
      (closest) face layer.
    * ``"depth"``: ``fragments.zbuf`` returned unchanged (all face layers).
    * ``"normal"``: vertex normals of the closest face, barycentrically
      interpolated and normalized to unit length; zero on pixels not covered
      by any face.
    * ``"ccm"``: vertex positions of the closest face, barycentrically
      interpolated and remapped through ``(xyz * ccm_scale + 1) / 2``; zero
      on pixels not covered by any face.

    Args:
        device: Device handed to the internal ``SoftPhongShader``.
        cameras: Cameras handed to the internal ``SoftPhongShader``.
        lights: Lights handed to the internal ``SoftPhongShader``.
        materials: Materials handed to the internal ``SoftPhongShader``.
        ccm_scale: Multiplier applied to vertex positions before the
            ``[-1, 1] -> [0, 1]`` remap of the ``"ccm"`` output, i.e.
            coordinates in ``[-1 / ccm_scale, 1 / ccm_scale]`` land in
            ``[0, 1]``.
        choices: Collection of output names to compute; ``None`` selects
            every name in ``DEFAULT_CHOICES``.
    """

    DEFAULT_CHOICES = ("rgb", "mask", "depth", "normal", "albedo", "ccm")

    def __init__(self, device, cameras, lights, materials, ccm_scale=1.0, choices=None):
        super().__init__()
        self.device = device
        self.cameras = cameras
        self.lights = lights
        self.materials = materials
        self.ccm_scale = ccm_scale
        self.choices = list(self.DEFAULT_CHOICES) if choices is None else choices

        # Very small sigma/gamma keep SoftPhongShader's soft blending close
        # to hard rasterization (crisp silhouettes, closest face dominates).
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        self.phong_shader = SoftPhongShader(
            device=self.device,
            cameras=self.cameras,
            lights=self.lights,
            materials=self.materials,
            blend_params=blend_params,
        )

    def forward(self, fragments, meshes, **kwargs):
        """Compute the requested outputs for ``meshes`` rasterized as ``fragments``.

        Args:
            fragments: Rasterizer output (``zbuf``, ``pix_to_face``,
                ``bary_coords`` are read).
            meshes: The meshes that produced ``fragments``.
            **kwargs: Forwarded unchanged to the internal ``SoftPhongShader``.

        Returns:
            dict mapping each requested output name to its tensor.
        """
        choices = self.choices
        output = {}

        # The mask is derived from the alpha channel of the Phong render, so
        # the render is needed whenever either "rgb" or "mask" is requested
        # (previously "mask" alone raised NameError).
        if "rgb" in choices or "mask" in choices:
            phong = self.phong_shader(fragments, meshes, **kwargs)
            if "rgb" in choices:
                output["rgb"] = phong[..., :3]
            if "mask" in choices:
                output["mask"] = (phong[..., 3:4] > 0).float()

        if "albedo" in choices:
            output["albedo"] = meshes.sample_textures(fragments)[..., 0, :]

        if "depth" in choices:
            output["depth"] = fragments.zbuf

        if "normal" in choices:
            output["normal"] = self._render_normal(fragments, meshes)

        if "ccm" in choices:
            output["ccm"] = self._render_ccm(fragments, meshes)

        return output

    @staticmethod
    def _interpolate_closest_face(fragments, faces_packed, verts_attr):
        """Barycentrically interpolate a per-vertex attribute over each pixel's closest face.

        Args:
            fragments: Rasterizer output; only the first face layer
                (``pix_to_face[..., 0]``, ``bary_coords[..., 0, :]``) is used.
            faces_packed: Packed triangle index tensor, one row of three
                vertex indices per face.
            verts_attr: Packed per-vertex attribute, one row per vertex.

        Returns:
            ``(values, valid)``: ``valid`` is a boolean tensor marking pixels
            covered by a face (``pix_to_face >= 0``); ``values`` holds the
            interpolated attribute for those pixels only, ordered like the
            ``True`` entries of ``valid``.
        """
        pix_to_face = fragments.pix_to_face[..., 0]
        bary = fragments.bary_coords[..., 0, :]
        valid = pix_to_face >= 0
        # Gather only covered pixels: background pixels carry
        # pix_to_face == -1, which would otherwise wrap to the last face.
        face_attrs = verts_attr[faces_packed[pix_to_face[valid]]]  # (M, 3, D)
        weights = bary[valid].unsqueeze(-1)  # (M, 3, 1)
        return (weights * face_attrs).sum(dim=-2), valid

    def _render_normal(self, fragments, meshes):
        """Return unit interpolated vertex normals, zero where no face is hit."""
        normals_packed = meshes.verts_normals_packed()
        values, valid = self._interpolate_closest_face(
            fragments, meshes.faces_packed(), normals_packed
        )
        normal = normals_packed.new_zeros((*valid.shape, 3))
        # normalize() clamps the norm with a small eps, so a degenerate
        # (zero-length) interpolated normal becomes 0 instead of NaN.
        normal[valid] = torch.nn.functional.normalize(values, dim=-1)
        return normal

    def _render_ccm(self, fragments, meshes):
        """Return remapped interpolated vertex positions, zero where no face is hit."""
        verts_packed = meshes.verts_packed()
        xyz, valid = self._interpolate_closest_face(
            fragments, meshes.faces_packed(), verts_packed
        )
        ccm = verts_packed.new_zeros((*valid.shape, 3))
        ccm[valid] = (xyz * self.ccm_scale + 1) / 2
        return ccm