ttxskk
update
d7e58f0
import torch
def vis_z_buffer(z, percentile=1, vis_pad=0.2):
    """Render channel 0 of a depth buffer as an 8-bit grayscale image.

    Pixels whose depth is greater than ``1e-5`` are treated as valid. Their
    depths are stretched between the ``percentile``-th and the
    ``(100 - percentile)``-th percentiles of the valid depths. That range is
    widened on both sides by ``vis_pad`` times its width, and depths are
    inverted so that smaller (nearer) depths come out brighter. Invalid pixels
    are drawn as 0.

    Args:
        z: Tensor indexed as ``z[:, :, 0]``; only that channel is used.
            The input tensor is not modified.
        percentile: Percentile used to clip outliers at each end
            (e.g. ``1`` keeps the 1st..99th percentile range).
        vis_pad: Fraction of the percentile range added as margin on each side.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with shape ``z.shape[:2]``.
    """
    # z[:, :, 0] is a view; copy it so the writes below never reach the
    # caller's tensor.
    depth = z[:, :, 0].detach().clone()
    mask = depth > 1e-5
    out = torch.zeros_like(depth)
    if mask.any():
        valid = depth[mask]
        vmin = torch.quantile(valid, percentile / 100)
        vmax = torch.quantile(valid, 1 - percentile / 100)
        pad = (vmax - vmin) * vis_pad
        vmin_padded = vmin - pad
        span = (vmax + pad) - vmin_padded
        # Mirror valid depths inside [vmin, vmax] so nearer points map brighter.
        flipped = vmin + vmax - valid
        if span > 0:
            normalized = (flipped - vmin_padded) / span
        else:
            # Every valid depth is identical: 0/0 would give NaN, which has no
            # defined uint8 conversion. Draw those pixels mid-gray instead.
            normalized = torch.full_like(flipped, 0.5)
        # Only valid pixels are written, so the background stays exactly 0
        # even when vmin_padded < 0.
        out[mask] = torch.clip(torch.round(normalized * 255), 0, 255)
    return out.to(dtype=torch.uint8).cpu().numpy()
def vis_normals(coords, normals, vis_pad=0.2):
    """Shade each valid pixel by how directly its normal faces the origin.

    A pixel is valid when ``coords[:, :, 2] > 0``. For a valid pixel with
    coordinate ``p`` and normal ``n`` the score is ``dot(-p, n) / |p|``. The
    score is clamped to [0, 1] and mapped linearly onto
    ``[vis_pad, 1 - vis_pad]``. Invalid pixels stay 0.

    Args:
        coords: Per-pixel coordinates; presumably ``(H, W, 3)`` given the
            ``[:, :, 2]`` indexing — confirm with callers.
        normals: Per-pixel normals aligned with ``coords``. The score is only
            a true cosine when these are unit length, which is not enforced.
        vis_pad: Margin kept at each end of the shading range.

    Returns:
        ``torch.uint8`` tensor of shape ``normals.shape[:2]`` on
        ``coords.device``. The final cast truncates rather than rounds.
    """
    valid = coords[:, :, 2] > 0
    # Vectors from each valid point back to the origin; |p| > 0 because z > 0.
    to_origin = -coords[valid]
    dist = (to_origin ** 2).sum(dim=1).sqrt()
    facing = (to_origin * normals[valid]).sum(dim=1) / dist

    scale = 1 - 2 * vis_pad
    shade = coords.new_zeros(normals.shape[:2])
    shade[valid] = facing.clamp(0, 1) * scale + vis_pad
    return shade.mul(255).to(dtype=torch.uint8)