# LiDAR-Diffusion / sample_cond.py
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from lidm.models.diffusion.ddim import DDIMSampler
from lidm.utils.misc_utils import instantiate_from_config
from lidm.utils.lidar_utils import range2pcd

# sampling settings
CUSTOM_STEPS = 50  # number of DDIM steps
ETA = 1.0  # DDIM eta (1.0 = fully stochastic, 0.0 = deterministic)

# model paths (camera-to-LiDAR model trained on KITTI)
MODEL_PATH = './models/lidm/kitti/cam2lidar'
CFG_PATH = os.path.join(MODEL_PATH, 'config.yaml')
CKPT_PATH = os.path.join(MODEL_PATH, 'model.ckpt')

# model config
MODEL_CFG = OmegaConf.load(CFG_PATH)


def custom_to_pcd(x, config, rgb=None):
    """Convert a range image in [-1, 1] (optionally with RGB) to a point cloud."""
    x = x.squeeze().detach().cpu().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.  # rescale to [0, 1]
    if rgb is not None:
        rgb = rgb.squeeze().detach().cpu().numpy()
        rgb = (np.clip(rgb, -1., 1.) + 1.) / 2.
        rgb = rgb.transpose(1, 2, 0)  # CHW -> HWC
    xyz, rgb, _ = range2pcd(x, color=rgb, **config['data']['params']['dataset'])
    return xyz, rgb


def custom_to_pil(x):
    """Convert a tensor in [-1, 1] to a PIL image."""
    x = x.detach().cpu().squeeze().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.  # rescale to [0, 1]
    x = (255 * x).astype(np.uint8)
    if x.ndim == 3:
        x = x.transpose(1, 2, 0)  # CHW -> HWC
    x = Image.fromarray(x)
    return x


def logs2pil(logs, keys=["sample"]):
    """Convert a dict of logged tensors to PIL images, one per key."""
    imgs = dict()
    for k in logs:
        try:
            if len(logs[k].shape) == 4:
                img = custom_to_pil(logs[k][0, ...])  # first item of the batch
            elif len(logs[k].shape) == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}.")
                img = None
        except Exception:
            img = None
        imgs[k] = img
    return imgs


def load_model_from_config(config, sd):
    """Instantiate the model from config and load a checkpoint state dict."""
    model = instantiate_from_config(config)
    model.load_state_dict(sd, strict=False)
    model.eval()
    return model


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, verbose=False):
    """Run conditional DDIM sampling and return samples plus intermediates."""
    ddim = DDIMSampler(model)
    bs = shape[0]  # split the batch size off the sample shape
    shape = shape[1:]
    samples, intermediates = ddim.sample(steps, conditioning=cond, batch_size=bs, shape=shape,
                                         eta=eta, verbose=verbose, disable_tqdm=True)
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch, batch_size, custom_steps=None, eta=1.0):
    """Sample range images conditioned on a camera image, decoded from latent space."""
    xc = batch['camera']
    c = model.get_learned_conditioning(xc.to(model.device))
    with model.ema_scope("Plotting"):  # use EMA weights for sampling
        samples, z_denoise_row = model.sample_log(cond=c, batch_size=batch_size, ddim=True,
                                                  ddim_steps=custom_steps, eta=eta)
    x_samples = model.decode_first_stage(samples)
    return x_samples


def sample(model, cond):
    """Generate one range image and point cloud from a camera-image condition."""
    batch = {'camera': cond}
    img = make_convolutional_sample(model, batch, batch_size=1, custom_steps=CUSTOM_STEPS,
                                    eta=ETA)  # TODO add arguments for batch_size, custom_steps and eta
    pcd = custom_to_pcd(img, MODEL_CFG)[0].astype(np.float32)  # keep xyz only
    img = img.squeeze().detach().cpu().numpy()
    return img, pcd
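

# Minimal usage sketch, not part of the original script. It assumes the checkpoint
# stores its weights under a 'state_dict' key (as LDM-style checkpoints typically do)
# and that `cond` is a preprocessed camera-image tensor in [-1, 1] whose shape matches
# what the conditioning encoder expects; the shape below is a placeholder.
if __name__ == '__main__':
    ckpt = torch.load(CKPT_PATH, map_location='cpu')
    model = load_model_from_config(MODEL_CFG.model, ckpt['state_dict'])
    if torch.cuda.is_available():
        model = model.cuda()
    # placeholder condition; replace with a real camera image tensor (1, C, H, W)
    cond = torch.randn(1, 3, 64, 64)
    img, pcd = sample(model, cond)
    print(f'range image: {img.shape}, point cloud: {pcd.shape}')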