abreza commited on
Commit
8e8cc15
1 Parent(s): 7297952

Update scripts/utils.py

Browse files
Files changed (1) hide show
  1. scripts/utils.py +252 -110
scripts/utils.py CHANGED
@@ -1,19 +1,18 @@
1
  import torch
2
  import numpy as np
3
  from PIL import Image
4
- import pymeshlab as ml
 
5
  from pytorch3d.renderer import TexturesVertex
6
  from pytorch3d.structures import Meshes
7
  from rembg import new_session, remove
8
- import trimesh
9
- from typing import List, Tuple
10
  import torch.nn.functional as F
 
 
11
 
12
  # Constants
13
- NEG_PROMPT = "sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy, bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, (worst quality:1.4), (low quality:1.4)"
14
 
15
- # CUDA Configuration
16
- CUDA_PROVIDERS = [
17
  ('CUDAExecutionProvider', {
18
  'device_id': 0,
19
  'arena_extend_strategy': 'kSameAsRequested',
@@ -22,21 +21,22 @@ CUDA_PROVIDERS = [
22
  })
23
  ]
24
 
25
- # Initialize rembg session
26
- rembg_session = new_session(providers=CUDA_PROVIDERS)
 
 
 
 
27
 
28
- # Mesh Loading and Conversion Functions
29
  def load_mesh_with_trimesh(file_name, file_type=None):
30
  mesh = trimesh.load(file_name, file_type=file_type)
31
  if isinstance(mesh, trimesh.Scene):
32
  mesh = _process_trimesh_scene(mesh)
33
-
34
- vertices = torch.from_numpy(mesh.vertices).T
35
- faces = torch.from_numpy(mesh.faces).T
36
- colors = _get_mesh_colors(mesh)
37
-
38
  return vertices, faces, colors
39
 
 
40
  def _process_trimesh_scene(mesh):
41
  from io import BytesIO
42
  with BytesIO() as f:
@@ -46,31 +46,63 @@ def _process_trimesh_scene(mesh):
46
  if isinstance(mesh, trimesh.Scene):
47
  mesh = trimesh.util.concatenate(
48
  tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
49
- for g in mesh.geometry.values()))
50
  return mesh
51
 
 
 
 
 
 
 
 
 
52
  def _get_mesh_colors(mesh):
53
  if mesh.visual is not None and hasattr(mesh.visual, 'vertex_colors'):
54
  return torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
55
- return torch.ones_like(mesh.vertices.T) * 0.5
 
56
 
57
- # Mesh Conversion Functions
58
- def meshlab_mesh_to_py3dmesh(mesh: ml.Mesh) -> Meshes:
59
  verts = torch.from_numpy(mesh.vertex_matrix()).float()
60
  faces = torch.from_numpy(mesh.face_matrix()).long()
61
  colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
62
  textures = TexturesVertex(verts_features=[colors])
63
  return Meshes(verts=[verts], faces=[faces], textures=textures)
64
 
65
- def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> ml.Mesh:
66
- colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
67
- return ml.Mesh(
 
 
68
  vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
69
  face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
70
- v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64),
 
71
  v_color_matrix=colors_in)
72
 
73
- # Normal Map Rotation Functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
75
  angle_rad = np.radians(angle)
76
  R = np.array([
@@ -80,105 +112,164 @@ def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
80
  ])
81
  return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
82
 
83
- def rotate_normals(normal_pils, return_types='np', rotate_direction=1):
 
84
  n_views = len(normal_pils)
85
  ret = []
86
  for idx, rgba_normal in enumerate(normal_pils):
87
- normal_np = _process_normal_map(rgba_normal, idx, n_views, rotate_direction)
88
- ret.append(_format_output(normal_np, return_types))
 
 
 
89
  return ret
90
 
91
- def _process_normal_map(rgba_normal, idx, n_views, rotate_direction):
 
92
  normal_np = np.array(rgba_normal)[:, :, :3] / 255 * 2 - 1
93
  alpha_np = np.array(rgba_normal)[:, :, 3] / 255
94
- normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views))
95
- normal_np = (normal_np + 1) / 2 * alpha_np[..., None]
 
 
 
 
96
  return np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255], axis=-1)
97
 
98
- def _format_output(normal_np, return_types):
 
99
  if return_types == 'np':
100
- return normal_np
101
  elif return_types == 'pil':
102
- return Image.fromarray(normal_np.astype(np.uint8))
103
  else:
104
- raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- # Background Change Functions
107
  def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
108
  new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
109
- return [_process_image(rgba_img, new_bkgd) for rgba_img in img_pils]
 
110
 
111
- def _process_image(rgba_img, new_bkgd):
112
- img_np = np.array(rgba_img)[:, :, :3] / 255
113
- alpha_np = np.array(rgba_img)[:, :, 3] / 255
114
  ori_bkgd = img_np[:1, :1]
115
  alpha_np_clamp = np.clip(alpha_np, 1e-6, 1)
116
- ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
117
- img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
118
- rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1)
 
 
 
119
  return Image.fromarray(rgba_img_np.astype(np.uint8))
120
 
121
- # Mesh Cleaning Function
122
- def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
123
- ms = ml.MeshSet()
124
- ms.add_mesh(pyml_mesh, "cube_mesh")
125
-
126
- if apply_smooth:
127
- ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
128
- if apply_sub_divide:
129
- ms.apply_filter("meshing_repair_non_manifold_vertices")
130
- ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
131
- ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=ml.PercentageValue(sub_divide_threshold))
132
- return meshlab_mesh_to_py3dmesh(ms.current_mesh())
133
 
134
- # Image Processing Functions
135
- def expand2square(pil_img, background_color):
136
- width, height = pil_img.size
137
- if width == height:
138
- return pil_img
139
- new_size = max(width, height)
140
- result = Image.new(pil_img.mode, (new_size, new_size), background_color)
141
- offset = ((new_size - width) // 2, (new_size - height) // 2)
142
- result.paste(pil_img, offset)
143
- return result
144
 
145
- def simple_preprocess(input_image, rembg_session=rembg_session, background_color=255):
146
- RES = 2048
147
- input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
148
- if input_image.mode != 'RGBA':
149
- image_rem = input_image.convert('RGBA')
150
- input_image = remove(image_rem, alpha_matting=False, session=rembg_session)
151
 
152
- arr = np.asarray(input_image)
153
- alpha = arr[:, :, -1]
154
- x_nonzero, y_nonzero = np.nonzero(alpha > 60)
155
- x_min, x_max = x_nonzero.min(), x_nonzero.max()
156
- y_min, y_max = y_nonzero.min(), y_nonzero.max()
157
- arr = arr[x_min:x_max+1, y_min:y_max+1]
158
- input_image = Image.fromarray(arr)
159
- return expand2square(input_image, (background_color, background_color, background_color, 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- # Mesh Saving Functions
162
  def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
163
- vertices = meshes.verts_packed().cpu().float().numpy()
164
- triangles = meshes.faces_packed().cpu().long().numpy()
165
- np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
166
-
167
  if save_glb_path.endswith(".glb"):
168
  vertices[:, [0, 2]] = -vertices[:, [0, 2]]
169
 
170
  if apply_sRGB_to_LinearRGB:
171
  np_color = srgb_to_linear(np_color)
172
-
173
- mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
 
174
  mesh.remove_unreferenced_vertices()
175
  mesh.export(save_glb_path)
176
-
177
  if save_glb_path.endswith(".glb"):
178
  fix_vert_color_glb(save_glb_path)
179
- print(f"Saved to {save_glb_path}")
180
 
181
- def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, **kwargs) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
182
  import time
183
  if '.' in save_mesh_prefix:
184
  save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
@@ -188,38 +279,89 @@ def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=Tru
188
  save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
189
  return ret_mesh, None
190
 
191
- # Utility Functions
192
- def srgb_to_linear(c_srgb):
193
- return np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4).clip(0, 1.)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- def fix_vert_color_glb(mesh_path):
196
- from pygltflib import GLTF2, Material, PbrMetallicRoughness
197
- obj1 = GLTF2().load(mesh_path)
198
- obj1.meshes[0].primitives[0].material = 0
199
- obj1.materials.append(Material(
200
- pbrMetallicRoughness = PbrMetallicRoughness(
201
- baseColorFactor = [1.0, 1.0, 1.0, 1.0],
202
- metallicFactor = 0.,
203
- roughnessFactor = 1.0,
204
- ),
205
- emissiveFactor = [0.0, 0.0, 0.0],
206
- doubleSided = True,
207
- ))
208
- obj1.save(mesh_path)
209
 
210
  def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
211
- new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
212
- imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
213
- img_nps, alpha_nps = imgs[..., :3], imgs[..., 3]
 
 
 
 
214
  ori_bkgds = img_nps[:, :1, :1]
215
-
216
  alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
217
- ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
 
218
  ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
219
- img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd)
 
 
 
 
 
220
 
221
  return torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
222
 
 
223
  def save_obj_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, **kwargs) -> Tuple[str, str]:
224
  if '.' in save_mesh_prefix:
225
  save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
 
1
  import torch
2
  import numpy as np
3
  from PIL import Image
4
+ import pymeshlab
5
+ import trimesh
6
  from pytorch3d.renderer import TexturesVertex
7
  from pytorch3d.structures import Meshes
8
  from rembg import new_session, remove
 
 
9
  import torch.nn.functional as F
10
+ from typing import List, Tuple
11
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
12
 
13
  # Constants
 
14
 
15
+ providers = [
 
16
  ('CUDAExecutionProvider', {
17
  'device_id': 0,
18
  'arena_extend_strategy': 'kSameAsRequested',
 
21
  })
22
  ]
23
 
24
+ session = new_session(providers=providers)
25
+
26
+ NEG_PROMPT = "sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)"
27
+
28
+ # Helper functions
29
+
30
 
 
31
  def load_mesh_with_trimesh(file_name, file_type=None):
32
  mesh = trimesh.load(file_name, file_type=file_type)
33
  if isinstance(mesh, trimesh.Scene):
34
  mesh = _process_trimesh_scene(mesh)
35
+
36
+ vertices, faces, colors = _extract_mesh_data(mesh)
 
 
 
37
  return vertices, faces, colors
38
 
39
+
40
  def _process_trimesh_scene(mesh):
41
  from io import BytesIO
42
  with BytesIO() as f:
 
46
  if isinstance(mesh, trimesh.Scene):
47
  mesh = trimesh.util.concatenate(
48
  tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
49
+ for g in mesh.geometry.values()))
50
  return mesh
51
 
52
+
53
+ def _extract_mesh_data(mesh):
54
+ vertices = torch.from_numpy(mesh.vertices).T
55
+ faces = torch.from_numpy(mesh.faces).T
56
+ colors = _get_mesh_colors(mesh)
57
+ return vertices, faces, colors
58
+
59
+
60
  def _get_mesh_colors(mesh):
61
  if mesh.visual is not None and hasattr(mesh.visual, 'vertex_colors'):
62
  return torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
63
+ return torch.ones_like(mesh.vertices).T * 0.5
64
+
65
 
66
+ def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes:
 
67
  verts = torch.from_numpy(mesh.vertex_matrix()).float()
68
  faces = torch.from_numpy(mesh.face_matrix()).long()
69
  colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
70
  textures = TexturesVertex(verts_features=[colors])
71
  return Meshes(verts=[verts], faces=[faces], textures=textures)
72
 
73
+
74
+ def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
75
+ colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [
76
+ 0, 1], value=1).numpy().astype(np.float64)
77
+ return pymeshlab.Mesh(
78
  vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
79
  face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
80
+ v_normals_matrix=meshes.verts_normals_packed(
81
+ ).cpu().float().numpy().astype(np.float64),
82
  v_color_matrix=colors_in)
83
 
84
+
85
+ def to_pyml_mesh(vertices, faces):
86
+ return pymeshlab.Mesh(
87
+ vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
88
+ face_matrix=faces.cpu().long().numpy().astype(np.int32),
89
+ )
90
+
91
+
92
+ def to_py3d_mesh(vertices, faces, normals=None):
93
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
94
+ if normals is None:
95
+ normals = mesh.verts_normals_packed()
96
+ mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
97
+ return mesh
98
+
99
+
100
+ def from_py3d_mesh(mesh):
101
+ return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
102
+
103
+ # Normal map rotation functions
104
+
105
+
106
  def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
107
  angle_rad = np.radians(angle)
108
  R = np.array([
 
112
  ])
113
  return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
114
 
115
+
116
+ def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray:
117
  n_views = len(normal_pils)
118
  ret = []
119
  for idx, rgba_normal in enumerate(normal_pils):
120
+ normal_np, alpha_np = _process_normal_image(rgba_normal)
121
+ normal_np = rotate_normalmap_by_angle(
122
+ normal_np, rotate_direction * idx * (360 / n_views))
123
+ rgba_normal_np = _combine_normal_and_alpha(normal_np, alpha_np)
124
+ ret.append(_format_output(rgba_normal_np, return_types))
125
  return ret
126
 
127
+
128
+ def _process_normal_image(rgba_normal):
129
  normal_np = np.array(rgba_normal)[:, :, :3] / 255 * 2 - 1
130
  alpha_np = np.array(rgba_normal)[:, :, 3] / 255
131
+ return normal_np, alpha_np
132
+
133
+
134
+ def _combine_normal_and_alpha(normal_np, alpha_np):
135
+ normal_np = (normal_np + 1) / 2
136
+ normal_np = normal_np * alpha_np[..., None]
137
  return np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255], axis=-1)
138
 
139
+
140
+ def _format_output(rgba_normal_np, return_types):
141
  if return_types == 'np':
142
+ return rgba_normal_np
143
  elif return_types == 'pil':
144
+ return Image.fromarray(rgba_normal_np.astype(np.uint8))
145
  else:
146
+ raise ValueError(
147
+ f"return_types should be 'np' or 'pil', but got {return_types}")
148
+
149
+
150
+ def rotate_normalmap_by_angle_torch(normal_map, angle):
151
+ angle_rad = torch.tensor(np.radians(angle)).to(normal_map)
152
+ R = torch.tensor([
153
+ [torch.cos(angle_rad), 0, torch.sin(angle_rad)],
154
+ [0, 1, 0],
155
+ [-torch.sin(angle_rad), 0, torch.cos(angle_rad)]
156
+ ]).to(normal_map)
157
+ return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
158
+
159
+
160
+ def do_rotate(rgba_normal, angle):
161
+ rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255
162
+ rotated_normal_tensor = rotate_normalmap_by_angle_torch(
163
+ rgba_normal[..., :3] * 2 - 1, angle)
164
+ rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
165
+ rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]]
166
+ return torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy()
167
+
168
+
169
+ def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1):
170
+ n_views = len(normal_pils)
171
+ ret = []
172
+ for idx, rgba_normal in enumerate(normal_pils):
173
+ angle = rotate_direction * idx * (360 / n_views)
174
+ rgba_normal_np = do_rotate(np.array(rgba_normal), angle)
175
+ ret.append(_format_output(rgba_normal_np, return_types))
176
+ return ret
177
+
178
+ # Background change functions
179
+
180
 
 
181
  def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
182
  new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
183
+ return [_change_single_image_bkgd(rgba_img, new_bkgd) for rgba_img in img_pils]
184
+
185
 
186
+ def _change_single_image_bkgd(rgba_img, new_bkgd):
187
+ img_np, alpha_np = np.array(
188
+ rgba_img)[:, :, :3] / 255, np.array(rgba_img)[:, :, 3] / 255
189
  ori_bkgd = img_np[:1, :1]
190
  alpha_np_clamp = np.clip(alpha_np, 1e-6, 1)
191
+ ori_img_np = (img_np - ori_bkgd *
192
+ (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
193
+ img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np *
194
+ alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
195
+ rgba_img_np = np.concatenate(
196
+ [img_np * 255, alpha_np[..., None] * 255], axis=-1)
197
  return Image.fromarray(rgba_img_np.astype(np.uint8))
198
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ def change_bkgd_to_normal(normal_pils) -> List[Image.Image]:
201
+ n_views = len(normal_pils)
202
+ return [_change_single_normal_bkgd(rgba_normal, idx, n_views) for idx, rgba_normal in enumerate(normal_pils)]
 
 
 
 
 
 
 
203
 
 
 
 
 
 
 
204
 
205
+ def _change_single_normal_bkgd(rgba_normal, idx, n_views):
206
+ target_bkgd = rotate_normalmap_by_angle(
207
+ np.array([[[0., 0., 1.]]]), idx * (360 / n_views))
208
+ normal_np, alpha_np = np.array(
209
+ rgba_normal)[:, :, :3] / 255 * 2 - 1, np.array(rgba_normal)[:, :, 3] / 255
210
+ old_bkgd = normal_np[:1, :1]
211
+ normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (
212
+ 1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None]
213
+ normal_np = normal_np * alpha_np[..., None] + \
214
+ target_bkgd * (1 - alpha_np[..., None])
215
+ normal_np = (normal_np + 1) / 2
216
+ rgba_normal_np = np.concatenate(
217
+ [normal_np * 255, alpha_np[..., None] * 255], axis=-1)
218
+ return Image.fromarray(rgba_normal_np.astype(np.uint8))
219
+
220
+ # Mesh and GLB handling functions
221
+
222
+
223
+ def fix_vert_color_glb(mesh_path):
224
+ obj1 = GLTF2().load(mesh_path)
225
+ obj1.meshes[0].primitives[0].material = 0
226
+ obj1.materials.append(Material(
227
+ pbrMetallicRoughness=PbrMetallicRoughness(
228
+ baseColorFactor=[1.0, 1.0, 1.0, 1.0],
229
+ metallicFactor=0.,
230
+ roughnessFactor=1.0,
231
+ ),
232
+ emissiveFactor=[0.0, 0.0, 0.0],
233
+ doubleSided=True,
234
+ ))
235
+ obj1.save(mesh_path)
236
+
237
+
238
+ def srgb_to_linear(c_srgb):
239
+ return np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4).clip(0, 1.)
240
+
241
 
 
242
  def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
243
+ vertices, triangles, np_color = _extract_mesh_data_for_trimesh(meshes)
244
+
 
 
245
  if save_glb_path.endswith(".glb"):
246
  vertices[:, [0, 2]] = -vertices[:, [0, 2]]
247
 
248
  if apply_sRGB_to_LinearRGB:
249
  np_color = srgb_to_linear(np_color)
250
+
251
+ mesh = trimesh.Trimesh(
252
+ vertices=vertices, faces=triangles, vertex_colors=np_color)
253
  mesh.remove_unreferenced_vertices()
254
  mesh.export(save_glb_path)
255
+
256
  if save_glb_path.endswith(".glb"):
257
  fix_vert_color_glb(save_glb_path)
258
+ print(f"saving to {save_glb_path}")
259
 
260
+
261
+ def _extract_mesh_data_for_trimesh(meshes):
262
+ vertices = meshes.verts_packed().cpu().float().numpy()
263
+ triangles = meshes.faces_packed().cpu().long().numpy()
264
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
265
+ assert vertices.shape[0] == np_color.shape[0]
266
+ assert np_color.shape[1] == 3
267
+ assert 0 <= np_color.min() and np_color.max(
268
+ ) <= 1, f"min={np_color.min()}, max={np_color.max()}"
269
+ return vertices, triangles, np_color
270
+
271
+
272
+ def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
273
  import time
274
  if '.' in save_mesh_prefix:
275
  save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
 
279
  save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
280
  return ret_mesh, None
281
 
282
+ # Mesh cleaning and preprocessing functions (continued)
283
+
284
+
285
+ def simple_clean_mesh(pyml_mesh: pymeshlab.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
286
+ ms = pymeshlab.MeshSet()
287
+ ms.add_mesh(pyml_mesh, "cube_mesh")
288
+
289
+ if apply_smooth:
290
+ ms.apply_filter("apply_coord_laplacian_smoothing",
291
+ stepsmoothnum=stepsmoothnum, cotangentweight=False)
292
+
293
+ if apply_sub_divide:
294
+ ms.apply_filter("meshing_repair_non_manifold_vertices")
295
+ ms.apply_filter("meshing_repair_non_manifold_edges",
296
+ method='Remove Faces')
297
+ ms.apply_filter("meshing_surface_subdivision_loop", iterations=2,
298
+ threshold=pymeshlab.PercentageValue(sub_divide_threshold))
299
+
300
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())
301
+
302
+
303
+ def expand2square(pil_img, background_color):
304
+ width, height = pil_img.size
305
+ if width == height:
306
+ return pil_img
307
+
308
+ new_size = max(width, height)
309
+ result = Image.new(pil_img.mode, (new_size, new_size), background_color)
310
+
311
+ if width > height:
312
+ result.paste(pil_img, (0, (width - height) // 2))
313
+ else:
314
+ result.paste(pil_img, ((height - width) // 2, 0))
315
+
316
+ return result
317
+
318
+
319
+ def simple_preprocess(input_image, rembg_session=session, background_color=255):
320
+ RES = 2048
321
+ input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
322
+
323
+ if input_image.mode != 'RGBA':
324
+ image_rem = input_image.convert('RGBA')
325
+ input_image = remove(
326
+ image_rem, alpha_matting=False, session=rembg_session)
327
+
328
+ arr = np.asarray(input_image)
329
+ alpha = arr[:, :, -1]
330
+
331
+ x_nonzero, y_nonzero = (alpha > 60).sum(axis=1).nonzero()[
332
+ 0], (alpha > 60).sum(axis=0).nonzero()[0]
333
+ x_min, x_max = int(x_nonzero.min()), int(x_nonzero.max())
334
+ y_min, y_max = int(y_nonzero.min()), int(y_nonzero.max())
335
+
336
+ arr = arr[x_min:x_max, y_min:y_max]
337
+ input_image = Image.fromarray(arr)
338
+ return expand2square(input_image, (background_color, background_color, background_color, 0))
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
  def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
342
+ new_bkgd = torch.tensor(
343
+ new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
344
+ imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32))
345
+ for img in img_pils]).to(device) / 255
346
+
347
+ img_nps = imgs[..., :3]
348
+ alpha_nps = imgs[..., 3]
349
  ori_bkgds = img_nps[:, :1, :1]
350
+
351
  alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
352
+ ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))
353
+ ) / alpha_nps_clamp.unsqueeze(-1)
354
  ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
355
+
356
+ img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05,
357
+ ori_img_nps *
358
+ alpha_nps.unsqueeze(-1) + new_bkgd *
359
+ (1 - alpha_nps.unsqueeze(-1)),
360
+ new_bkgd)
361
 
362
  return torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
363
 
364
+
365
  def save_obj_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, **kwargs) -> Tuple[str, str]:
366
  if '.' in save_mesh_prefix:
367
  save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])