# Unique3D / scripts / mesh_init.py
# Committed by Wuvin in commit 37aeb5b ("init").
from PIL import Image
import torch
import numpy as np
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex
from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
import pymeshlab
# Thread count forwarded to estimate_height_map by normalmap_to_depthmap.
_MAX_THREAD = 8
# rgb and depth to mesh
def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
    """Build one orthographic ray per pixel, all pointing along +z.

    Ray origins lie on the z=0 plane. x spans roughly [-1, 1] across the image
    width and y spans roughly [-H/W, H/W] across the height, so the image
    aspect ratio is preserved. Every direction is (0, 0, 1).

    Args:
        W: image width in pixels.
        H: image height in pixels.
        use_pixel_centers: sample at pixel centres (offset 0.5) instead of
            pixel corners.
        device: torch device for the returned tensors.

    Returns:
        ``(origins, directions)``, two float32 tensors of shape (H, W, 3).
        Note the order: origins first, despite the function name.
    """
    pixel_center = 0.5 if use_pixel_centers else 0
    cols = torch.arange(W, dtype=torch.float32, device=device) + pixel_center
    rows = torch.arange(H, dtype=torch.float32, device=device) + pixel_center
    # 'xy' indexing yields (H, W) grids: i[r, c] = cols[c], j[r, c] = rows[r].
    i, j = torch.meshgrid(cols, rows, indexing="xy")
    origins = torch.stack([(i / W - 0.5) * 2, (j / H - 0.5) * 2 * H / W, torch.zeros_like(i)], dim=-1)  # H, W, 3
    directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1)  # H, W, 3
    return origins, directions
def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
    """Triangulate a per-pixel depth map into a vertex-coloured pytorch3d mesh.

    Every pixel becomes a vertex placed along its orthographic ray (from
    ``get_ortho_ray_directions_origins``) scaled by ``pred_HWC``. Each 2x2
    pixel quad is split into two triangles. A triangle is kept only when all
    three of its corner pixels are valid. Vertices not referenced by any kept
    triangle are dropped, and face indices are renumbered accordingly.

    Args:
        rgb_BCHW: colour batch; only item 0 is used. Colours are mapped with
            ``c / 2 + 0.5``, so the input is presumably expected in [-1, 1]
            (TODO confirm with callers).
        pred_HWC: per-pixel depth, broadcast against the (H, W, 3) rays.
            Presumably shaped (H, W, 1).
        valid_HWC: boolean mask laid out like ``pred_HWC``. All pixels are
            valid when omitted.
        is_back: if True, negate the x and z vertex coordinates.

    Returns:
        A ``Meshes`` with one mesh and per-vertex ``TexturesVertex`` colours.
    """
    if valid_HWC is None:
        valid_HWC = torch.ones_like(pred_HWC).bool()
    H, W = rgb_BCHW.shape[-2:]
    device = rgb_BCHW.device

    # Reverse the row axis of every input so image row 0 becomes the last mesh row.
    image = rgb_BCHW.flip(-2)
    depth = pred_HWC.flip(0)
    mask = valid_HWC.flip(0)

    ray_o, ray_d = get_ortho_ray_directions_origins(W, H, device=device)
    vertices = (ray_o + ray_d * depth).reshape(-1, 3)  # (H*W, 3)

    # Flat vertex id of every pixel; pRC = pixel offset by R rows and C cols.
    ids = torch.arange(H * W).reshape(H, W).to(device)
    p00, p01, p10, p11 = ids[:-1, :-1], ids[:-1, 1:], ids[1:, :-1], ids[1:, 1:]
    m00, m01, m10, m11 = mask[:-1, :-1], mask[:-1, 1:], mask[1:, :-1], mask[1:, 1:]

    kept_triangles = []
    for corners, keep in (
        ((p00, p01, p10), m00 & m01 & m10),
        ((p11, p10, p01), m11 & m10 & m01),
    ):
        tri = torch.stack(corners, dim=-1)
        # The mask's trailing axis is broadcast over the 3 corner slots.
        kept_triangles.append(tri[keep.expand_as(tri)].reshape(-1, 3))
    faces = torch.cat(kept_triangles, dim=0)  # (F, 3)

    colors = (image[0].permute((1, 2, 0)) / 2 + 0.5).reshape(-1, 3)  # (H*W, 3)

    if is_back:
        mirror = torch.tensor([-1, 1, -1], dtype=vertices.dtype, device=vertices.device)
        vertices = vertices * mirror

    # Keep only referenced vertices and renumber faces into the compacted list.
    used = faces.unique()
    remap = torch.zeros_like(vertices[..., 0]).long()
    remap[used] = torch.arange(used.shape[0], device=vertices.device)
    return Meshes(
        verts=[vertices[used]],
        faces=[remap[faces]],
        textures=TexturesVertex(verts_features=[colors[used]]),
    )
def normalmap_to_depthmap(normal_np):
    """Integrate a normal map into a height map.

    Thin wrapper around ``scripts.normal_to_height_map.estimate_height_map``,
    which is imported lazily. It is called with ``raw_values=True``,
    ``_MAX_THREAD`` threads and a target of 96 iterations, and its result is
    returned unchanged.
    """
    from scripts.normal_to_height_map import estimate_height_map as _estimate
    return _estimate(
        normal_np,
        raw_values=True,
        thread_count=_MAX_THREAD,
        target_iteration_count=96,
    )
def transform_back_normal_to_front(normal_pil):
    """Convert a back-view normal image to the front-view convention.

    Channels 0 and 2 are inverted (``255 - v``). Channel 1 and any extra
    channel (e.g. alpha) are left untouched. Assuming normals are encoded as
    ``(n + 1) * 127.5``, this negates their x and z components (TODO confirm
    the encoding).

    Returns:
        A new PIL image built from the modified uint8 array.
    """
    pixels = np.array(normal_pil)  # 0..255 values for 8-bit images
    for channel in (0, 2):
        pixels[..., channel] = 255 - pixels[..., channel]
    return Image.fromarray(pixels.astype(np.uint8))
def calc_w_over_h(normal_pil):
    """Return the width/height ratio of the foreground's bounding box.

    Foreground pixels are chosen as follows:
    - 4-channel input: pixels whose last channel satisfies ``alpha / 255 >= 0.5``.
    - Otherwise: every pixel except those whose channels are all >= 250
      (near-white background).

    Extents are measured between the first and last foreground indices
    (``max - min``). A foreground exactly one pixel tall therefore divides by
    zero, and numpy yields inf/nan with a RuntimeWarning.

    Args:
        normal_pil: a PIL image or a numpy array laid out as (H, W, C).

    Returns:
        ``(w_max - w_min) / (h_max - h_min)`` as a numpy float.

    Raises:
        TypeError: if the input is neither a numpy array nor a PIL image.
        ValueError: if no foreground pixel is found.
    """
    if isinstance(normal_pil, np.ndarray):
        arr = normal_pil
    elif isinstance(normal_pil, Image.Image):
        arr = np.array(normal_pil)
    else:
        # Explicit check rather than `assert`, which `python -O` strips.
        raise TypeError(f"expected a PIL image or numpy array, got {type(normal_pil).__name__}")

    if arr.shape[-1] == 4:
        # Same rule as binarising alpha/255 at 0.5: only values below 0.5 are background.
        mask = ~(arr[..., -1] / 255. < 0.5)
    else:
        mask = ~(arr.min(axis=-1) >= 250)

    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        raise ValueError("calc_w_over_h: no foreground pixels found")
    return (cols.max() - cols.min()) / (rows.max() - rows.min())
def _normalize_to_unit_range(values):
    """Shift ``values`` so its minimum is 0, then divide by the new maximum.

    A constant input (zero range) returns all zeros instead of the NaNs that
    a plain ``0 / 0`` would produce.
    """
    shifted = values - values.min()
    peak = shifted.max()
    if peak == 0:
        return np.zeros_like(shifted)
    return shifted / peak


def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0, device="cuda"):
    """Initialise a single-view relief mesh from a normal map and an RGB image.

    Steps:
    1. Integrate the normal map into a height map (``normalmap_to_depthmap``).
    2. Rescale the heights according to ``init_type``.
    3. Pin the foreground silhouette (Canny edges of the alpha mask) to zero
       height.
    4. Triangulate the result with ``depth_and_color_to_mesh``.

    Args:
        normal_pil: normal map image. For back views it is first converted
            with ``transform_back_normal_to_front``.
        rgb_pil: colour image of the same size.
        is_back: whether the inputs come from the back view. This is also
            forwarded to the mesher.
        clamp_min: pixels whose height is below this value are marked invalid.
        scale: overall depth scale.
        init_type: how heights are rescaled.
            - ``"std"``: heights / max, times the foreground's share of the
              image height, times ``2 * scale``.
            - ``"thin"``: heights mapped to [0, 0.2], times ``scale``.
            - anything else: heights mapped to [offset, 1], times ``scale``.
        offset: lower bound of the normalised heights in the default branch.
        device: torch device the tensors are moved to before meshing.
            Defaults to "cuda", the previously hard-coded value.

    Returns:
        The ``Meshes`` produced by ``depth_and_color_to_mesh``.

    Raises:
        ValueError: if neither the normal nor the RGB image has an alpha
            channel, or (from numpy) if the alpha mask has no foreground.
    """
    # Still a lazy import, but done first so a missing OpenCV fails before the
    # height-map integration.
    import cv2

    if is_back:
        normal_pil = transform_back_normal_to_front(normal_pil)
    normal_img = np.array(normal_pil)
    rgb_img = np.array(rgb_pil)

    # Foreground alpha in [0, 1], shape (H, W, 1). The normal map's alpha wins.
    if normal_img.shape[-1] == 4:
        alpha = normal_img[..., [3]] / 255
    elif rgb_img.shape[-1] == 4:
        alpha = rgb_img[..., [3]] / 255
    else:
        raise ValueError("invalid input, either normal or rgb should have alpha channel")

    # Vertical extent of the foreground in pixels (used by the "std" init).
    fg_rows = np.where(alpha > 0.5)[0]
    real_height_pix = np.max(fg_rows) - np.min(fg_rows)

    heights = normalmap_to_depthmap(normal_img)
    # NOTE(review): colours are passed in [0, 1], but depth_and_color_to_mesh
    # maps them with `/ 2 + 0.5`, so vertex colours end up in [0.5, 1].
    # Left unchanged; confirm whether downstream code overwrites them.
    rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2, 0, 1))[None]
    valid_HWC = torch.from_numpy(alpha >= 0.5)  # (H, W, 1) bool

    if init_type == "std":
        # accurate but not stable
        depth = heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2
    elif init_type == "thin":
        depth = _normalize_to_unit_range(heights) * 0.2 * scale
    else:
        # stable but not accurate: heights mapped to [offset, 1] before scaling
        depth = (_normalize_to_unit_range(heights) * (1 - offset) + offset) * scale
    pred_HWC = torch.from_numpy(depth).float()[..., None]

    # Set the border (silhouette edge) pixels to 0 height.
    edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
    edge = torch.from_numpy(edge).bool()[..., None]
    pred_HWC[edge] = 0
    # Drop pixels whose height falls below clamp_min.
    valid_HWC[pred_HWC < clamp_min] = False
    return depth_and_color_to_mesh(rgb_BCHW.to(device), pred_HWC.to(device), valid_HWC.to(device), is_back)
def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
    """Rebuild a mesh with pymeshlab's screened Poisson surface reconstruction.

    Steps:
    1. Convert the pytorch3d mesh into a pymeshlab ``MeshSet``.
    2. If ``simplification > 0``, quadric-decimate to that many faces
       (topology preserved).
    3. Run ``generate_surface_reconstruction_screened_poisson`` with 6 threads,
       ``depth=poissson_depth`` and pre-cleaning.
    4. If ``simplification > 0``, decimate again.
    5. Convert the current mesh back to pytorch3d.

    Presumably used to close open borders (per the name) — confirm with callers.
    """
    mesh_set = pymeshlab.MeshSet()
    mesh_set.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")

    def decimate():
        # Quadric edge collapse down to `simplification` faces, keeping topology.
        mesh_set.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)

    should_simplify = simplification > 0
    if should_simplify:
        decimate()
    mesh_set.apply_filter('generate_surface_reconstruction_screened_poisson', threads=6, depth=poissson_depth, preclean=True)
    if should_simplify:
        decimate()
    return meshlab_mesh_to_py3dmesh(mesh_set.current_mesh())