# Flux-TRELLIS / svrm / ldm / vis_util.py
# Uploaded by gokaygokay ("Upload 93 files", commit 0a88b62, verified; 2.92 kB).
import os
from typing import List, Optional
from PIL import Image
import imageio
import time
import torch
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.common.datatypes import Device
from pytorch3d.structures import Meshes
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
PointLights,
DirectionalLights,
AmbientLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
TexturesUV,
TexturesVertex,
camera_position_from_spherical_angles,
BlendParams,
)
def _to_uint8_frame(image, rgb):
    """Convert one rendered frame (float tensor, channels last) to a uint8 numpy array.

    image: a single element of the renderer's batched output, indexed by
        ``render`` as (H, W, C) on the CPU; values are expected in [0, 1].
    rgb: if True, keep only the first three channels (drop alpha).

    Returns an (H, W, 3) array when ``rgb`` is set, otherwise (H, W, C).
    """
    if rgb:
        image = image[..., :3]
    # Clamp before scaling: float -> uint8 casts do not saturate, so a value even
    # slightly outside [0, 1] could otherwise wrap around (e.g. 256 -> 0).
    return (image.clamp(0.0, 1.0) * 255).round().to(torch.uint8).numpy()


def render(
    obj_filename,
    elev=0,
    azim=0,
    resolution=512,
    gif_dst_path='',
    n_views=120,
    fps=30,
    device="cuda:0",
    rgb=False
):
    '''Render an OBJ mesh either as a single image or as an orbiting GIF.

    obj_filename: path to obj file
    elev: camera elevation passed to look_at_view_transform (degrees by
        that function's default); used in both modes
    azim: camera azimuth for single-frame rendering only; the GIF orbit
        always sweeps azimuth over [0, 360) in n_views equal steps
    resolution: output image size passed to RasterizationSettings
    gif_dst_path:
        if set a path, will render n_views frames, save them to a gif file
        and return the path
        if not set (empty string or None), will render a single frame, then
        return a PIL.Image instance
    n_views: number of GIF frames; must be >= 1 in GIF mode (unused otherwise)
    fps: GIF playback rate; 1000 / fps is passed as imageio's `duration`
        (milliseconds per frame — NOTE(review): imageio's GIF duration unit
        differs between plugin versions; confirm for the installed version)
    device: torch device used for loading and rendering the mesh
    rgb: if set true, drop the alpha channel and return/write RGB frames;
        otherwise all rendered channels are kept

    Raises ValueError if gif_dst_path is set and n_views < 1.
    '''
    make_gif = bool(gif_dst_path)
    if make_gif and n_views < 1:
        raise ValueError(f"n_views must be >= 1 when rendering a gif, got {n_views}")

    # load mesh
    mesh = load_objs_as_meshes([obj_filename], device=device)
    if make_gif:
        # One mesh copy per camera so the whole orbit renders as one batch.
        meshes = mesh.extend(n_views)
        elev = torch.full((n_views,), float(elev))
        # n_views evenly spaced azimuths; the endpoint is dropped so 360 == 0
        # is not rendered twice.
        azim = torch.linspace(0, 360, n_views + 1)[:-1]
    else:
        # Single frame: render exactly one view instead of n_views identical ones.
        meshes = mesh

    # prepare R,T then compute cameras
    R, T = look_at_view_transform(dist=1.5, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=49.1)
    # init pytorch3d renderer instance
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(
                image_size=resolution,
                blur_radius=0.0,
                faces_per_pixel=1,
            ),
        ),
        shader=SoftPhongShader(
            device=device,
            cameras=cameras,
            lights=AmbientLights(device=device),
            blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
        ),
    )
    # Visualisation only: no autograd graph is needed.
    with torch.no_grad():
        images = renderer(meshes)
    # Transfer the whole batch to host memory once rather than per frame.
    images = images.cpu()

    # single frame rendering
    if not make_gif:
        return Image.fromarray(_to_uint8_frame(images[0], rgb))

    # orbit frames rendering
    with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
        for image in images:
            writer.append_data(_to_uint8_frame(image, rgb))
    return gif_dst_path