File size: 6,148 Bytes
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from PIL import Image
import torch
import numpy as np
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex
from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
import pymeshlab

# Worker-thread count handed to estimate_height_map() by normalmap_to_depthmap().
_MAX_THREAD = 8

# rgb and depth to mesh
def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
    """Build per-pixel ray origins and directions for an orthographic camera.

    Args:
        W: image width in pixels.
        H: image height in pixels.
        use_pixel_centers: when true, sample each pixel at its centre
            (offset 0.5) instead of its top-left corner.
        device: torch device the returned tensors are moved to.

    Returns:
        A tuple ``(origins, directions)`` of float32 tensors, each of shape
        (H, W, 3). Origins lie on the z = 0 plane with x scaled to span
        [-1, 1] across the width and y scaled by the same factor (so it
        spans [-H/W, H/W]). Every direction is the +z unit vector (0, 0, 1).
    """
    shift = 0.5 if use_pixel_centers else 0
    cols = np.arange(W, dtype=np.float32) + shift
    rows = np.arange(H, dtype=np.float32) + shift
    # 'xy' indexing yields grids of shape (H, W): grid_x varies along columns.
    grid_x, grid_y = np.meshgrid(cols, rows, indexing='xy')
    grid_x = torch.from_numpy(grid_x).to(device)
    grid_y = torch.from_numpy(grid_y).to(device)

    x = (grid_x / W - 0.5) * 2
    # Divide by W (not H) so that pixels stay square in world space.
    y = (grid_y / H - 0.5) * 2 * H / W
    z = torch.zeros_like(grid_x)
    origins = torch.stack([x, y, z], dim=-1)  # (H, W, 3)

    directions = torch.stack(
        [torch.zeros_like(grid_x), torch.zeros_like(grid_y), torch.ones_like(grid_x)],
        dim=-1,
    )  # (H, W, 3)

    return origins, directions

def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
    """Triangulate a depth map into a vertex-coloured mesh.

    Each pixel becomes a candidate vertex placed on an orthographic ray at
    the predicted depth; each 2x2 pixel neighbourhood contributes up to two
    triangles. A triangle is kept only when all three of its corners are
    valid, and vertices not referenced by any kept triangle are dropped.

    Args:
        rgb_BCHW: colour tensor of shape (B, C, H, W); only batch item 0 is
            used. Colours are computed as ``rgb / 2 + 0.5``.
            NOTE(review): that mapping suggests input in [-1, 1], whereas
            build_mesh() in this file passes values in [0, 1] -- confirm
            which range is intended.
        pred_HWC: per-pixel depth, broadcast against the (H, W, 3) ray
            directions (build_mesh() passes shape (H, W, 1)).
        valid_HWC: boolean mask with the same layout as ``pred_HWC``;
            defaults to all-true.
        is_back: when true, vertex x and z coordinates are negated.

    Returns:
        A ``Meshes`` object holding one mesh with ``TexturesVertex`` colours.
    """
    if valid_HWC is None:
        valid_HWC = torch.ones_like(pred_HWC).bool()
    device = rgb_BCHW.device
    H, W = rgb_BCHW.shape[-2:]

    # Flip every input along the image-row axis before building geometry.
    rgb_BCHW = rgb_BCHW.flip(-2)
    pred_HWC = pred_HWC.flip(0)
    valid_HWC = valid_HWC.flip(0)

    rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=device)
    verts = (rays_o + rays_d * pred_HWC).reshape(-1, 3)  # (H*W, 3)

    # Flat vertex index of every pixel, and the four corners of each 2x2 cell.
    indexes = torch.arange(H * W).reshape(H, W).to(device)
    idx_00, idx_01 = indexes[:-1, :-1], indexes[:-1, 1:]
    idx_10, idx_11 = indexes[1:, :-1], indexes[1:, 1:]
    ok_00, ok_01 = valid_HWC[:-1, :-1], valid_HWC[:-1, 1:]
    ok_10, ok_11 = valid_HWC[1:, :-1], valid_HWC[1:, 1:]

    # Two triangles per cell; a triangle survives only if all corners are valid.
    faces1 = torch.stack([idx_00, idx_01, idx_10], dim=-1)
    faces1_valid = ok_00 & ok_01 & ok_10
    faces2 = torch.stack([idx_11, idx_10, idx_01], dim=-1)
    faces2_valid = ok_11 & ok_10 & ok_01
    kept1 = faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3)
    kept2 = faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)
    faces = torch.cat([kept1, kept2], dim=0)  # (F, 3)

    colors = (rgb_BCHW[0].permute((1, 2, 0)) / 2 + 0.5).reshape(-1, 3)  # (H*W, 3)

    if is_back:
        mirror = torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)
        verts = verts * mirror

    # Compact the vertex list to the vertices actually used and re-index faces.
    used_verts = faces.unique()
    old_to_new_mapping = torch.zeros(verts.shape[0], dtype=torch.long, device=verts.device)
    old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
    new_faces = old_to_new_mapping[faces]

    textures = TexturesVertex(verts_features=[colors[used_verts]])
    return Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=textures)

def normalmap_to_depthmap(normal_np):
    """Estimate a height map from a normal map.

    Thin wrapper around ``estimate_height_map`` that requests raw values,
    uses ``_MAX_THREAD`` worker threads and a target of 96 iterations.

    Args:
        normal_np: normal map forwarded unchanged to ``estimate_height_map``.

    Returns:
        Whatever ``estimate_height_map`` returns for these settings.
    """
    # Imported lazily, so the dependency is only needed when this is called.
    from scripts.normal_to_height_map import estimate_height_map

    return estimate_height_map(
        normal_np,
        raw_values=True,
        thread_count=_MAX_THREAD,
        target_iteration_count=96,
    )

def transform_back_normal_to_front(normal_pil):
    """Return a copy of a normal-map image with channels 0 and 2 inverted.

    Each value ``v`` in the first and third channel is replaced by
    ``255 - v``; all other channels are left untouched.

    Args:
        normal_pil: image whose pixel values lie in [0, 255].

    Returns:
        A new PIL image built from the modified uint8 pixel array.
    """
    pixels = np.array(normal_pil)  # fresh copy, safe to modify in place
    for channel in (0, 2):
        pixels[..., channel] = 255 - pixels[..., channel]
    return Image.fromarray(pixels.astype(np.uint8))

def calc_w_over_h(normal_pil):
    """Return the width/height ratio of the foreground bounding box.

    The foreground is determined from the image content:

    * 4-channel input: pixels whose last channel (alpha) is >= 0.5 after
      division by 255.
    * otherwise: pixels whose minimum channel value is below 250, i.e.
      everything that is not near-white.

    Args:
        normal_pil: a PIL image or an ``np.ndarray`` of shape (H, W, C).
            Any other array-like is converted with ``np.array``.

    Returns:
        ``(w_max - w_min) / (h_max - h_min)`` over the foreground pixel
        coordinates. If the foreground spans a single row the division is
        by zero and follows NumPy semantics (inf/nan with a warning).

    Raises:
        ValueError: if the input is not 3-dimensional or contains no
            foreground pixels.
    """
    # ndarrays are used as-is; PIL images (and other array-likes) are copied
    # into an array. An explicit check replaces the former `assert`, which
    # would be stripped when running with -O.
    arr = normal_pil if isinstance(normal_pil, np.ndarray) else np.array(normal_pil)
    if arr.ndim != 3:
        raise ValueError(
            f"expected an image array of shape (H, W, C), got shape {arr.shape}"
        )

    if arr.shape[-1] == 4:
        foreground = arr[..., -1] / 255. >= 0.5
    else:
        foreground = ~(arr.min(axis=-1) >= 250)

    # Locate the foreground once and reuse the coordinates for both extremes.
    rows, cols = np.nonzero(foreground)
    if rows.size == 0:
        raise ValueError("image contains no foreground pixels")

    return (cols.max() - cols.min()) / (rows.max() - rows.min())

def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0):
    """Build an initial coloured mesh from a normal map and an RGB image.

    The normal map is integrated into a height map, rescaled according to
    ``init_type``, zeroed along the silhouette edge and then triangulated
    by ``depth_and_color_to_mesh``. Tensors are moved to CUDA before
    meshing.

    Args:
        normal_pil: normal-map image; for a back view it is first passed
            through ``transform_back_normal_to_front``.
        rgb_pil: colour image supplying the vertex colours.
        is_back: whether the inputs depict the back side.
        clamp_min: pixels whose depth falls below this value are removed
            from the valid mask.
        scale: multiplier applied to the normalised heights.
        init_type: ``"std"`` scales heights by the foreground's pixel
            height relative to the image height ("accurate but not
            stable"); ``"thin"`` normalises heights to [0, 0.2]; any other
            value normalises them to [offset, 1] ("stable but not
            accurate").
        offset: lower bound of the normalised range in the fallback mode.

    Returns:
        The mesh returned by ``depth_and_color_to_mesh``.

    Raises:
        ValueError: if neither input image has a fourth (alpha) channel.
    """
    if is_back:
        normal_pil = transform_back_normal_to_front(normal_pil)
    normal_img = np.array(normal_pil)
    rgb_img = np.array(rgb_pil)

    # The foreground mask comes from an alpha channel; the normal map's wins.
    if normal_img.shape[-1] == 4:
        alpha_source = normal_img
    elif rgb_img.shape[-1] == 4:
        alpha_source = rgb_img
    else:
        raise ValueError("invalid input, either normal or rgb should have alpha channel")
    alpha_HWC = alpha_source[..., [3]] / 255

    # Vertical extent of the foreground, in pixels.
    foreground_rows = np.where(alpha_HWC > 0.5)[0]
    real_height_pix = np.max(foreground_rows) - np.min(foreground_rows)

    heights = normalmap_to_depthmap(normal_img)
    rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2, 0, 1))[None]
    valid_HWC = torch.from_numpy(alpha_HWC >= 0.5)  # binarised alpha as a bool mask

    if init_type == "std":
        # accurate but not stable
        coverage = real_height_pix / heights.shape[0]
        depth = heights / heights.max() * coverage * scale * 2
    else:
        shifted = heights - heights.min()
        unit = shifted / shifted.max()  # heights rescaled to [0, 1]
        if init_type == "thin":
            depth = unit * 0.2 * scale
        else:
            # stable but not accurate: heights mapped to [offset, 1]
            depth = (unit * (1 - offset) + offset) * scale
    pred_HWC = torch.from_numpy(depth).float()[..., None]

    # Set the border pixels of the mask to zero height.
    import cv2
    mask_u8 = (valid_HWC[..., 0] * 255).numpy().astype(np.uint8)
    border = torch.from_numpy(cv2.Canny(mask_u8, 0, 255)).bool()[..., None]
    pred_HWC[border] = 0

    valid_HWC[pred_HWC < clamp_min] = False
    return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)

def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
    """Rebuild a mesh surface with screened Poisson reconstruction.

    The mesh is loaded into a pymeshlab ``MeshSet``, optionally decimated,
    reconstructed with the screened Poisson filter (6 threads, pre-clean
    enabled) and optionally decimated again.

    Args:
        meshes: input mesh, converted with ``py3dmesh_to_meshlab_mesh``.
        poissson_depth: ``depth`` argument of the Poisson reconstruction
            filter.
        simplification: target face count for quadric edge-collapse
            decimation, applied both before and after reconstruction;
            values <= 0 skip decimation.

    Returns:
        The current mesh of the ``MeshSet`` converted back with
        ``meshlab_mesh_to_py3dmesh``.
    """
    mesh_set = pymeshlab.MeshSet()
    mesh_set.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")

    def _decimate_if_requested():
        # Topology-preserving quadric edge collapse down to the requested size.
        if simplification > 0:
            mesh_set.apply_filter(
                'meshing_decimation_quadric_edge_collapse',
                targetfacenum=simplification,
                preservetopology=True,
            )

    _decimate_if_requested()
    mesh_set.apply_filter(
        'generate_surface_reconstruction_screened_poisson',
        threads=6,
        depth=poissson_depth,
        preclean=True,
    )
    _decimate_if_requested()

    return meshlab_mesh_to_py3dmesh(mesh_set.current_mesh())