File size: 4,827 Bytes
1e5535f
 
02a9751
1e5535f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02a9751
 
1e5535f
 
 
 
6dafd85
 
 
 
 
 
 
 
 
7cad53a
02a9751
 
 
 
 
 
 
 
 
 
 
 
 
1e5535f
 
02a9751
1e5535f
 
 
 
 
 
 
 
ded6c2a
1e5535f
 
235efa3
 
 
02a9751
 
1e5535f
 
 
 
 
 
235efa3
1e5535f
235efa3
1e5535f
 
02a9751
1e5535f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dafd85
1e5535f
 
 
 
6dafd85
 
 
1e5535f
 
235efa3
1e5535f
 
 
 
 
 
 
 
 
 
 
 
 
02a9751
 
6dafd85
02a9751
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import pytorch3d
import torch
import imageio
import numpy as np
import os
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
    AmbientLights,
    PerspectiveCameras, 
    RasterizationSettings, 
    look_at_view_transform,
    TexturesVertex,
    MeshRenderer, 
    Materials,
    MeshRasterizer, 
    SoftPhongShader, 
    PointLights
)
import trimesh
from tqdm import tqdm
from pytorch3d.transforms import RotateAxisAngle

from shader import MultiOutputShader

def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Apply the linear-to-sRGB transfer function element-wise.

    Values at or below 0.0031308 use the linear segment (``12.92 * f``); larger
    values use the ``1.055 * f ** (1 / 2.4) - 0.055`` power curve. Intended for
    linear values in [0, 1]; nothing here clamps the output range.
    """
    # torch.where evaluates both branches eagerly, so the power-curve input is
    # clamped from below to keep pow() off negative bases (which would give NaN).
    return torch.where(
        f <= 0.0031308,
        f * 12.92,
        torch.pow(torch.clamp(f, 0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
    )


def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Convert linear RGB(A) to sRGB, passing an alpha channel through unchanged.

    Args:
        f: tensor whose last dimension has 3 (RGB) or 4 (RGBA) channels. Any
            number of leading dimensions is accepted, including none (a single
            colour of shape ``(3,)`` or ``(4,)``).

    Returns:
        Tensor with the same shape as ``f``. Only the first three channels are
        gamma-encoded; a fourth (alpha) channel is copied as-is.
    """
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    if f.shape[-1] == 4:
        out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1)
    else:
        out = _rgb_to_srgb(f)
    # Compare full shapes: checking shape[0] and shape[1] individually raised
    # IndexError for 1-D inputs and ignored every trailing dimension.
    assert out.shape == f.shape
    return out

def _load_vertex_colored_mesh(input_obj_path, device):
    """Load a mesh with trimesh and wrap it as a vertex-coloured PyTorch3D ``Meshes``.

    A ``trimesh.Scene`` is flattened by concatenating all of its geometries.
    Vertex colours keep trimesh's values (RGB only, the 4th channel is dropped),
    which the albedo post-processing treats as a 0-255 scale. When the visual
    exposes no vertex colours, every vertex is set to white (255).
    """
    scene_data = trimesh.load(input_obj_path)
    if isinstance(scene_data, trimesh.Scene):
        mesh_data = trimesh.util.concatenate(list(scene_data.geometry.values()))
    else:
        mesh_data = scene_data

    vertices = torch.tensor(mesh_data.vertices, dtype=torch.float32, device=device)
    faces = torch.tensor(mesh_data.faces, dtype=torch.int64, device=device)

    # Not every trimesh visual type has a ``vertex_colors`` attribute (e.g.
    # texture-based visuals), so look it up defensively instead of crashing.
    raw_colors = getattr(mesh_data.visual, "vertex_colors", None)
    if raw_colors is None:
        # 255, not 1.0: the albedo is divided by 255 downstream, so 1.0 would
        # render as near-black rather than white.
        vertex_colors = torch.full_like(vertices, 255.0)
    else:
        vertex_colors = torch.tensor(raw_colors[:, :3], dtype=torch.float32, device=device)

    # Features are created on ``device`` directly, so no extra .to() is needed.
    textures = TexturesVertex(verts_features=vertex_colors[None])
    return pytorch3d.structures.Meshes(verts=[vertices], faces=[faces], textures=textures)


def _compose_frame(render_result):
    """Build one uint8 video frame: sRGB albedo on the left, normal map on the right.

    Pixels outside ``render_result["mask"]`` are painted white in both halves.
    Assumes the shader returns batch-first image tensors (N, H, W, C) on a
    0-255 albedo scale -- TODO confirm against MultiOutputShader; only batch
    element 0 and the first three channels are used.
    """
    mask = render_result["mask"]

    # Gamma-encode the albedo in [0, 1], then return to the 0-255 scale.
    albedo = rgb_to_srgb(render_result["albedo"] / 255.0) * 255.0
    rgb_image = albedo * mask + (1 - mask) * torch.ones_like(albedo) * 255.0

    # Map unit normals from [-1, 1] to [0, 1] for display.
    normal_map = torch.nn.functional.normalize(render_result["normal"], dim=-1)
    normal_map = (normal_map + 1) / 2
    normal_map = normal_map * mask + (1 - mask) * torch.ones_like(render_result["normal"])

    rgb = np.clip(rgb_image[0, ..., :3].cpu().numpy(), 0, 255).astype(np.uint8)
    normal = np.clip(normal_map[0, ..., :3].cpu().numpy() * 255, 0, 255).astype(np.uint8)
    return np.concatenate((rgb, normal), axis=1)


def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, image_size=512, fps=15, device="cuda"):
    """Render a turntable video of a vertex-coloured mesh.

    The camera looks at the origin from distance 6.5 at elevation 0 and steps
    its azimuth evenly through 360 degrees over ``num_frames`` frames. Each
    frame shows the sRGB albedo next to the normal map, on a white background.

    Args:
        input_obj_path: mesh file to load with ``trimesh.load``.
        output_video_path: destination passed to ``imageio.mimsave``; missing
            parent directories are created.
        num_frames: number of camera positions (frames) around the orbit.
        image_size: rasterization resolution for ``RasterizationSettings``.
        fps: frame rate of the written video.
        device: torch device used for rendering.

    Raises:
        FileNotFoundError: if ``input_obj_path`` does not exist.
        ValueError: if ``num_frames`` is smaller than 1.
    """
    if not os.path.exists(input_obj_path):
        raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    mesh = _load_vertex_colored_mesh(input_obj_path, device)

    lights = AmbientLights(ambient_color=((2.0,) * 3,), device=device)
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
    )
    rasterizer = MeshRasterizer(raster_settings=raster_settings)
    materials = Materials(
        device=device,
        diffuse_color=((1.0, 1.0, 1.0),),
        ambient_color=((1.0, 1.0, 1.0),),
        specular_color=((1.0, 1.0, 1.0),),
        shininess=0.0,
    )

    camera_distance = 6.5
    elevation = 0.0
    center = (0.0, 0.0, 0.0)

    frames = []
    for i in tqdm(range(num_frames)):
        azimuth = 360.0 * i / num_frames
        R, T = look_at_view_transform(
            dist=camera_distance,
            elev=elevation,
            azim=azimuth,
            at=(center,),
            degrees=True,
        )

        # Build this frame's camera from the look-at rotation and translation.
        cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
        # Extreme near/far values, presumably so nothing is clipped -- confirm
        # how the rasterizer/shader consume them.
        cameras.znear = 0.0001
        cameras.zfar = 10000000.0
        shader = MultiOutputShader(
            device=device,
            cameras=cameras,
            lights=lights,
            materials=materials,
            choices=["rgb", "mask", "normal", "albedo"],
        )

        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
        render_result = renderer(mesh, cameras=cameras)
        frames.append(_compose_frame(render_result))

    out_dir = os.path.dirname(output_video_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(output_video_path, frames, fps=fps)

    print(f"Video saved to {output_video_path}")

if __name__ == '__main__':
    # Command-line entry point. The defaults reproduce the previously
    # hard-coded invocation, so running the script with no arguments is unchanged.
    import argparse

    parser = argparse.ArgumentParser(
        description="Render a turntable video (sRGB albedo | normal map) of a vertex-coloured mesh."
    )
    parser.add_argument(
        "input_obj_path",
        nargs="?",
        default="./354e2aee-091d-4dc6-bdb1-e09be5791218_isomer_recon_mesh.obj",
        help="mesh file to render",
    )
    parser.add_argument(
        "output_video_path",
        nargs="?",
        default="output.mp4",
        help="where to write the video",
    )
    parser.add_argument("--num-frames", type=int, default=60, help="frames in the 360-degree orbit")
    parser.add_argument("--image-size", type=int, default=512, help="render resolution in pixels")
    parser.add_argument("--fps", type=int, default=15, help="video frame rate")
    parser.add_argument("--device", default="cuda", help="torch device to render on")
    args = parser.parse_args()

    render_video_from_obj(
        args.input_obj_path,
        args.output_video_path,
        num_frames=args.num_frames,
        image_size=args.image_size,
        fps=args.fps,
        device=args.device,
    )