# High-resolution final visualization: render a movie of test views from a
# trained SJC voxel radiance field checkpoint.
import numpy as np
import torch
from einops import rearrange

from run_sjc import (
    SJC, ScoreAdapter, StableDiffusion,
    tqdm, EventStorage, HeartBeat, EarlyLoopBreak, get_event_storage, get_heartbeat, optional_load_config, read_stats,
    vis_routine, stitch_vis, latest_ckpt,
    scene_box_filter, render_ray_bundle, as_torch_tsrs,
    device_glb
)
from voxnerf.render import subpixel_rays_from_img
# the SD decoder is very memory hungry; the latent image cannot be too large.
# for a graphics card with < 12 GB memory, set this to 128; quality already good.
# if your card has 12 to 24 GB memory, you can set this to 200;
# but visually it won't help beyond a certain point. Our teaser is done with 128.
decoder_bottleneck_hw = 128
def final_vis():
    """Load the saved config and latest checkpoint, then render the final
    high-resolution visualization movie via :func:`evaluate`.
    """
    config = optional_load_config(fname="full_config.yml")
    assert len(config) > 0, "can't find cfg file"
    modules = SJC(**config)

    # NOTE: pop AFTER constructing SJC so the constructor still sees "family"
    family = config.pop("family")
    model: ScoreAdapter = getattr(modules, family).make()
    voxel = modules.vox.make()
    poser = modules.pose.make()

    progress = tqdm(range(1))

    with EventStorage(), HeartBeat(progress):
        weights = torch.load(latest_ckpt(), map_location="cpu")
        voxel.load_state_dict(weights)
        voxel.to(device_glb)

        with EventStorage("highres"):
            # what dominates the speed is NOT the factor here.
            # you can try from 2 to 8, and the speed is about the same.
            # the dominating factor in the pipeline I believe is the SD decoder.
            evaluate(model, voxel, poser, n_frames=200, factor=4)
def evaluate(score_model, vox, poser, n_frames=200, factor=4):
    """Render a trajectory of test views from a trained voxel grid and
    stitch the logged frames into two movie artifacts.

    Each view is rendered at sub-pixel resolution (``factor`` rays per pixel
    side), decoded through the SD decoder when ``score_model`` is a
    StableDiffusion adapter, and logged through the event storage.

    Args:
        score_model: score adapter; if it is a StableDiffusion instance the
            rendered latent image is decoded to RGB before logging.
        vox: trained voxel radiance field; moved to the global device here.
        poser: camera sampler providing ``H``/``W`` and ``sample_test``.
        n_frames: number of test poses to sample along the trajectory.
        factor: sub-pixel ray factor forwarded to highres_render_one_view.
    """
    H, W = poser.H, poser.W
    vox.eval()
    K, poses = poser.sample_test(n_frames)
    del n_frames  # the pose list is truncated below; iterate poses instead

    poses = poses[60:]  # skip the full overhead view; not interesting

    fuse = EarlyLoopBreak(5)
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)

    # fix: the original bound an unused `pbar :=` walrus and indexed poses by
    # hand; iterate the poses directly under tqdm instead.
    for pose in tqdm(poses):
        if fuse.on_break():
            break

        y, depth = highres_render_one_view(vox, aabb, H, W, K, pose, f=factor)
        if isinstance(score_model, StableDiffusion):
            y = score_model.decode(y)
        vis_routine(metric, y, depth)

        metric.step()
        hbeat.beat()

    metric.flush_history()

    metric.put_artifact(
        "movie_im_and_depth", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "view")[1])
    )

    metric.put_artifact(
        "movie_im_only", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "img")[1])
    )

    metric.step()
def highres_render_one_view(vox, aabb, H, W, K, pose, f=4, *, batch_size=4096):
    """Render one camera view at sub-pixel resolution, chunking rays to
    bound GPU memory, and resize the latent image to the decoder bottleneck.

    Args:
        vox: voxel radiance field; provides ``.device`` and is queried by
            render_ray_bundle.
        aabb: scene axis-aligned bounding box as a numpy array (as produced
            by ``vox.aabb.T.cpu().numpy()`` in the caller).
        H, W: base image height/width in pixels.
        K: camera intrinsics.
        pose: camera pose for this view.
        f: sub-pixel factor; the rendered image is (H*f, W*f) before resizing.
        batch_size: rays rendered per chunk; lower it on small GPUs.
            (Generalized from the previously hard-coded 4096.)

    Returns:
        tuple: ``(rgbs, depth)`` where ``rgbs`` is a
        (1, 4, decoder_bottleneck_hw, decoder_bottleneck_hw) latent tensor
        and ``depth`` is an (H*f, W*f) depth map.
    """
    ro, rd = subpixel_rays_from_img(H, W, K, pose, f=f)
    ro, rd, t_min, t_max = scene_box_filter(ro, rd, aabb)
    n = len(ro)
    ro, rd, t_min, t_max = as_torch_tsrs(vox.device, ro, rd, t_min, t_max)

    rgbs = torch.zeros(n, 4, device=vox.device)
    depth = torch.zeros(n, 1, device=vox.device)

    with torch.no_grad():
        for i in range(int(np.ceil(n / batch_size))):
            s = i * batch_size
            e = min(n, s + batch_size)
            _rgbs, _depth, _ = render_ray_bundle(
                vox, ro[s:e], rd[s:e], t_min[s:e], t_max[s:e]
            )
            rgbs[s:e] = _rgbs
            depth[s:e] = _depth

    rgbs = rearrange(rgbs, "(h w) c -> 1 c h w", h=H*f, w=W*f)
    depth = rearrange(depth, "(h w) 1 -> h w", h=H*f, w=W*f)

    # the SD decoder expects a fixed latent resolution; see decoder_bottleneck_hw
    rgbs = torch.nn.functional.interpolate(
        rgbs, (decoder_bottleneck_hw, decoder_bottleneck_hw),
        mode='bilinear', antialias=True
    )
    return rgbs, depth
# Script entry point: run the final high-resolution visualization.
if __name__ == "__main__":
    final_vis()