JiantaoLin committed on
Commit 1e5535f · 1 Parent(s): 30d56f8
Files changed (2)
  1. shader.py +76 -0
  2. video_render.py +107 -94
shader.py ADDED
@@ -0,0 +1,76 @@
+import torch
+from pytorch3d.renderer.mesh.shader import ShaderBase
+from pytorch3d.renderer import (
+    SoftPhongShader,
+)
+
+class MultiOutputShader(ShaderBase):
+    def __init__(self, device, cameras, lights, materials, ccm_scale=1.0, choices=None):
+        super().__init__()
+        self.device = device
+        self.cameras = cameras
+        self.lights = lights
+        self.materials = materials
+        self.ccm_scale = ccm_scale
+
+        if choices is None:
+            self.choices = ["rgb", "mask", "depth", "normal", "albedo", "ccm"]
+        else:
+            self.choices = choices
+
+        self.phong_shader = SoftPhongShader(
+            device=self.device,
+            cameras=self.cameras,
+            lights=self.lights,
+            materials=self.materials
+        )
+
+    def forward(self, fragments, meshes, **kwargs):
+        batch_size, H, W, _ = fragments.zbuf.shape
+        output = {}
+
+        if "rgb" in self.choices:
+            rgb_images = self.phong_shader(fragments, meshes, **kwargs)
+            rgb = rgb_images[..., :3]
+            output["rgb"] = rgb
+
+        if "mask" in self.choices:
+            alpha = rgb_images[..., 3:4]  # relies on the "rgb" pass above having run
+            mask = (alpha > 0).float()
+            output["mask"] = mask
+
+        if "albedo" in self.choices:
+            albedo = meshes.sample_textures(fragments)
+            output["albedo"] = albedo[..., 0, :]
+
+        if "depth" in self.choices:
+            depth = fragments.zbuf
+            output["depth"] = depth
+
+        if "normal" in self.choices:
+            pix_to_face = fragments.pix_to_face[..., 0]
+            bary_coords = fragments.bary_coords[..., 0, :]
+            valid_mask = pix_to_face >= 0
+            face_indices = pix_to_face[valid_mask]
+            faces_packed = meshes.faces_packed()
+            normals_packed = meshes.verts_normals_packed()
+            face_vertex_normals = normals_packed[faces_packed[face_indices]]
+            bary = bary_coords.view(-1, 3)[valid_mask.view(-1)]
+            interpolated_normals = (
+                bary[..., 0:1] * face_vertex_normals[:, 0, :] +
+                bary[..., 1:2] * face_vertex_normals[:, 1, :] +
+                bary[..., 2:3] * face_vertex_normals[:, 2, :]
+            )
+            interpolated_normals = interpolated_normals / interpolated_normals.norm(dim=-1, keepdim=True)
+            normal = torch.zeros(batch_size, H, W, 3, device=self.device)
+            normal[valid_mask] = interpolated_normals
+            output["normal"] = normal
+
+        if "ccm" in self.choices:
+            face_vertices = meshes.verts_packed()[meshes.faces_packed()]
+            faces_at_pixels = face_vertices[fragments.pix_to_face]
+            ccm = torch.sum(fragments.bary_coords.unsqueeze(-1) * faces_at_pixels, dim=-2)
+            ccm = (ccm[..., 0, :] * self.ccm_scale + 1) / 2
+            output["ccm"] = ccm
+
+        return output
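
Usage note: a minimal sketch of how MultiOutputShader plugs into a PyTorch3D MeshRenderer, mirroring the call pattern in video_render.py below. The mesh, cameras, lights, materials, and device objects are assumed to already exist on the same device, and faces_per_pixel=1 rasterization is assumed so that the hard mask and per-pixel normals are well defined.

    from pytorch3d.renderer import MeshRasterizer, MeshRenderer, RasterizationSettings
    from shader import MultiOutputShader

    # one face per pixel, no blur: hard silhouettes and unambiguous per-pixel normals
    rasterizer = MeshRasterizer(raster_settings=RasterizationSettings(image_size=512, blur_radius=0.0, faces_per_pixel=1))
    shader = MultiOutputShader(device=device, cameras=cameras, lights=lights, materials=materials, choices=["rgb", "mask", "normal"])
    renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
    out = renderer(mesh, cameras=cameras)  # dict of tensors: "rgb" (N, H, W, 3), "mask" (N, H, W, 1), "normal" (N, H, W, 3)
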
video_render.py CHANGED
@@ -1,28 +1,28 @@
-import os
-import math
-import numpy as np
+import pytorch3d
+import torch
 import imageio
+import numpy as np
+import os
+from pytorch3d.io import load_objs_as_meshes
+from pytorch3d.renderer import (
+    AmbientLights,
+    PerspectiveCameras,
+    RasterizationSettings,
+    look_at_view_transform,
+    TexturesVertex,
+    MeshRenderer,
+    Materials,
+    MeshRasterizer,
+    SoftPhongShader,
+    PointLights
+)
 import trimesh
-import pyrender
 from tqdm import tqdm
-# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
-os.environ['PYOPENGL_PLATFORM'] = 'egl'  # render with EGL (headless mode)
-
-def render_video_from_obj(input_obj_path, output_video_path, fps=15, frame_count=60, resolution=(512, 512)):
-    """
-    Render a rotating 3D model (OBJ file) to a video with RGB and normal map side-by-side.
-
-    Args:
-        input_obj_path (str): Path to the input OBJ file.
-        output_video_path (str): Path to save the output video.
-        fps (int): Frames per second for the video.
-        frame_count (int): Number of frames in the video.
-        resolution (tuple): Resolution of the rendered video (width, height).
-
-    Returns:
-        str: Path to the output video.
-    """
-    # Check that the input file exists
+from pytorch3d.transforms import RotateAxisAngle
+
+from shader import MultiOutputShader
+
+def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, image_size=512, fps=30, device="cuda"):
     if not os.path.exists(input_obj_path):
         raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
 
@@ -39,85 +39,98 @@ def render_video_from_obj(input_obj_path, output_video_path, fps=15, frame_count
     if not hasattr(mesh_data, 'vertex_normals') or mesh_data.vertex_normals is None:
         mesh_data.compute_vertex_normals()
 
-    # Create the pyrender scene with a white background
-    render_scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0])
-    mesh = pyrender.Mesh.from_trimesh(mesh_data, smooth=True)
-    mesh_node = render_scene.add(mesh)
-
-    # Camera parameters
-    camera = pyrender.PerspectiveCamera(yfov=np.deg2rad(30), znear=0.0001, zfar=100000.0)
-    camera_pose = np.eye(4)
-    camera_pose[2, 3] = 4.0  # distance from the model
-    render_scene.add(camera, pose=camera_pose)
-
-    # Add global ambient light
-    ambient_light = np.array([1.0, 1.0, 1.0]) * 2.0
-    render_scene.ambient_light = ambient_light
-
-    # Prepare the normal-map rendering
-    normals = mesh_data.vertex_normals.copy()
-
-    # Map the normals to the color range [0, 255]
-    normal_colors = ((normals + 1) / 2 * 255)
-
-    # Build a separate mesh for normal rendering
-    normal_mesh_data = mesh_data.copy()
-    normal_mesh_data.visual.vertex_colors = np.hstack(
-        [normal_colors, np.full((normals.shape[0], 1), 255, dtype=np.uint8)]  # add an alpha channel
+    # Get vertex positions, normals, and faces
+    vertices = torch.tensor(mesh_data.vertices, dtype=torch.float32, device=device)
+    faces = torch.tensor(mesh_data.faces, dtype=torch.int64, device=device)
+    vertex_normals = torch.tensor(mesh_data.vertex_normals, dtype=torch.float32)
+
+    # Get vertex colors
+    if mesh_data.visual.vertex_colors is None:
+        # If there are no vertex colors, fall back to a default (e.g. white)
+        vertex_colors = torch.ones_like(vertices)[None]
+    else:
+        vertex_colors = torch.tensor(mesh_data.visual.vertex_colors[:, :3], dtype=torch.float32)[None]
+    # Create the texture and assign the vertex colors
+    textures = TexturesVertex(verts_features=vertex_colors)
+    textures = textures.to(device)
+    # Build the Meshes object
+    mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces], textures=textures)
+
+    # Set up the renderer
+    lights = AmbientLights(ambient_color=((3.0,)*3,), device=device)
+    # lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]], ambient_color=[[0.5, 0.5, 0.5]], diffuse_color=[[1.0, 1.0, 1.0]])
+    raster_settings = RasterizationSettings(
+        image_size=image_size,  # size of the rendered image
+        blur_radius=0.0,  # no blur by default
+        faces_per_pixel=1,  # rasterize one face per pixel
+        # background_color=(1.0, 1.0, 1.0)
     )
 
-    # Create the normal-map rendering scene
-    normal_scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0])
-    normal_mesh = pyrender.Mesh.from_trimesh(normal_mesh_data, smooth=True)
-    normal_mesh_node = normal_scene.add(normal_mesh)
-    normal_scene.add(camera, pose=camera_pose)
-    normal_scene.ambient_light = ambient_light
-
-    # Initialize the renderer
-    r = pyrender.OffscreenRenderer(*resolution)
-
-    # Create the video writer
-    writer = imageio.get_writer(output_video_path, fps=fps)
-
+    # Rotation and rendering parameters
+    frames = []
+    camera_distance = 6.5
+    elevs = 0.0
+    center = (0.0, 0.0, 0.0)
     # Render each frame
-    try:
-        for frame_idx in tqdm(range(frame_count)):
-            # Compute the rotation angle
-            angle = 2 * np.pi * frame_idx / frame_count
-            rotation_matrix = np.array([
-                [math.cos(angle), 0, math.sin(angle), 0],
-                [0, 1, 0, 0],
-                [-math.sin(angle), 0, math.cos(angle), 0],
-                [0, 0, 0, 1]
-            ])
-
-            # Update the model pose
-            render_scene.set_pose(mesh_node, rotation_matrix)
-
-            # Render the RGB image
-            color, _ = r.render(render_scene)
-
-            # Update the pose in the normal scene
-            normal_scene.set_pose(normal_mesh_node, rotation_matrix)
-
-            # Render the normal image
-            normal, _ = r.render(normal_scene, flags=pyrender.RenderFlags.FLAT)
-
-            # Concatenate the two images side by side
-            combined_frame = np.concatenate((color, normal), axis=1)
-
-            # Write the video frame
-            writer.append_data(combined_frame)
-    finally:
-        # Release resources
-        writer.close()
-        r.delete()
-
-    print(f"Rendered video saved to {output_video_path}")
-    return output_video_path
+    materials = Materials(
+        device=device,
+        diffuse_color=((0.0, 0.0, 0.0),),
+        ambient_color=((1.0, 1.0, 1.0),),
+        specular_color=((0.0, 0.0, 0.0),),
+        shininess=0.0,
+    )
+
+    rasterizer = MeshRasterizer(raster_settings=raster_settings)
+    for i in tqdm(range(num_frames)):
+        azims = 360.0 * i / num_frames
+        R, T = look_at_view_transform(
+            dist=camera_distance,
+            elev=elevs,
+            azim=azims,
+            at=(center,),
+            degrees=True
+        )
+
+
+        # Set the camera pose manually
+        cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
+        cameras.znear = 0.0001
+        cameras.zfar = 10000000.0
+        shader = MultiOutputShader(
+            device=device,
+            cameras=cameras,
+            lights=lights,
+            materials=materials,
+            choices=["rgb", "mask", "normal"]
+        )
+
+        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+        # Render the RGB image and the normal map
+        render_result = renderer(mesh, cameras=cameras)
+        rgb_image = render_result["rgb"] * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["rgb"]) * 255.0
+        normal_map = render_result["normal"]
+
+        # Extract the RGB image and the normal map
+        rgb = rgb_image[0, ..., :3].cpu().numpy()  # RGB image
+        normal_map = torch.nn.functional.normalize(normal_map, dim=-1)  # normal map
+        normal_map = (normal_map + 1) / 2
+        normal_map = normal_map * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["normal"])
+        normal = normal_map[0, ..., :3].cpu().numpy()  # normal map
+        rgb = np.clip(rgb, 0, 255).astype(np.uint8)
+        normal = np.clip(normal * 255, 0, 255).astype(np.uint8)
+        # Combine RGB and normal map into one image: RGB on the left, normal map on the right
+        combined_image = np.concatenate((rgb, normal), axis=1)
+
+        # Append the combined image to the frame list
+        frames.append(combined_image)
+
+    # Save the video with imageio
+    imageio.mimsave(output_video_path, frames, fps=fps)
+
+    print(f"Video saved to {output_video_path}")
 
 if __name__ == '__main__':
     # Example call
-    input_obj_path = "output/gradio_cache/text_3D/_超级赛亚人_10/rgb_projected.obj"
+    input_obj_path = "/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/a_owl_wearing_a_hat/ISOMER/rgb_projected.obj"
     output_video_path = "output.mp4"
     render_video_from_obj(input_obj_path, output_video_path)
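
Side note: both the removed pyrender path and the new PyTorch3D path visualize normals with the same convention, mapping each unit normal n in [-1, 1]^3 to a color via (n + 1) / 2 and then scaling to [0, 255]. A tiny self-contained check of that mapping:

    import numpy as np
    n = np.array([0.0, 0.0, 1.0])  # unit normal pointing along +z
    color = ((n + 1) / 2 * 255).astype(np.uint8)
    print(color)  # [127 127 255] -- the familiar bluish tint of normal maps
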