import os

import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image

from lidm.models.diffusion.ddim import DDIMSampler
from lidm.utils.misc_utils import instantiate_from_config
from lidm.utils.lidar_utils import range2pcd

# sampling settings
CUSTOM_STEPS = 50
ETA = 1.0

# model paths
MODEL_PATH = './models/lidm/kitti/cam2lidar'
CFG_PATH = os.path.join(MODEL_PATH, 'config.yaml')
CKPT_PATH = os.path.join(MODEL_PATH, 'model.ckpt')

# model config
MODEL_CFG = OmegaConf.load(CFG_PATH)


def custom_to_pcd(x, config, rgb=None):
    """Convert a range image tensor in [-1, 1] to a point cloud (xyz, rgb)."""
    x = x.squeeze().detach().cpu().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.
    if rgb is not None:
        rgb = rgb.squeeze().detach().cpu().numpy()
        rgb = (np.clip(rgb, -1., 1.) + 1.) / 2.
        rgb = rgb.transpose(1, 2, 0)
    xyz, rgb, _ = range2pcd(x, color=rgb, **config['data']['params']['dataset'])
    return xyz, rgb


def custom_to_pil(x):
    """Convert a tensor in [-1, 1] to a PIL image."""
    x = x.detach().cpu().squeeze().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.
    x = (255 * x).astype(np.uint8)
    if x.ndim == 3:
        x = x.transpose(1, 2, 0)
    return Image.fromarray(x)


def logs2pil(logs, keys=["sample"]):
    """Convert logged tensors to PIL images, keyed like the input dict."""
    imgs = dict()
    for k in logs:
        try:
            if len(logs[k].shape) == 4:
                img = custom_to_pil(logs[k][0, ...])
            elif len(logs[k].shape) == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}.")
                img = None
        except Exception:
            img = None
        imgs[k] = img
    return imgs


def load_model_from_config(config, sd):
    model = instantiate_from_config(config)
    model.load_state_dict(sd, strict=False)
    model.eval()
    return model


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, verbose=False):
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    samples, intermediates = ddim.sample(steps, conditioning=cond, batch_size=bs,
                                         shape=shape, eta=eta, verbose=verbose,
                                         disable_tqdm=True)
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch, batch_size, custom_steps=None, eta=1.0):
    # encode the conditioning camera image, then sample in latent space with DDIM
    xc = batch['camera']
    c = model.get_learned_conditioning(xc.to(model.device))
    with model.ema_scope("Plotting"):
        samples, z_denoise_row = model.sample_log(cond=c, batch_size=batch_size,
                                                  ddim=True, ddim_steps=custom_steps,
                                                  eta=eta)
    x_samples = model.decode_first_stage(samples)
    return x_samples


def sample(model, cond):
    """Sample one range image from a camera condition and convert it to a point cloud."""
    batch = {'camera': cond}
    img = make_convolutional_sample(model, batch, batch_size=1,
                                    custom_steps=CUSTOM_STEPS, eta=ETA)
    # TODO: add arguments for batch_size, custom_steps and eta
    pcd = custom_to_pcd(img, MODEL_CFG)[0].astype(np.float32)
    img = img.squeeze().detach().cpu().numpy()
    return img, pcd
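

# Minimal usage sketch wiring the pieces above together: load the checkpoint,
# build the model, and run one camera-conditioned sample. The 'state_dict'
# checkpoint key and the `MODEL_CFG.model` sub-config follow the usual
# latent-diffusion / PyTorch Lightning conventions and are assumptions here,
# as are the shape and content of the placeholder `camera` tensor.
if __name__ == '__main__':
    ckpt = torch.load(CKPT_PATH, map_location='cpu')
    sd = ckpt.get('state_dict', ckpt)  # assumed Lightning-style checkpoint layout
    model = load_model_from_config(MODEL_CFG.model, sd)
    if torch.cuda.is_available():
        model = model.cuda()

    # placeholder conditioning input; a real run would pass a preprocessed
    # camera image tensor in the range expected by the conditioning encoder
    camera = torch.randn(1, 3, 64, 64, device=model.device)

    range_img, pcd = sample(model, camera)
    print('range image:', range_img.shape, 'point cloud:', pcd.shape)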