import logging
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.distributions as dists
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.transformer_arch import TransformerMultiHead
from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
                                     VectorQuantizerTexture)

logger = logging.getLogger('base')


class TransformerTextureAwareModel():
    """Texture-Aware Diffusion based Transformer model."""

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        # VQVAE for image
        self.img_encoder = Encoder(
            ch=opt['img_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            double_z=opt['img_double_z'],
            dropout=opt['img_dropout']).to(self.device)
        self.img_decoder = Decoder(
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            ch=opt['img_ch'],
            out_ch=opt['img_out_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            dropout=opt['img_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.img_quantizer = VectorQuantizerTexture(
            opt['img_n_embed'], opt['img_embed_dim'],
            beta=0.25).to(self.device)
        self.img_quant_conv = torch.nn.Conv2d(opt['img_z_channels'],
                                              opt['img_embed_dim'],
                                              1).to(self.device)
        self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
                                                   opt['img_z_channels'],
                                                   1).to(self.device)
        self.load_pretrained_image_vae()

        # VAE for segmentation mask
        self.segm_encoder = Encoder(
            ch=opt['segm_ch'],
            num_res_blocks=opt['segm_num_res_blocks'],
            attn_resolutions=opt['segm_attn_resolutions'],
            ch_mult=opt['segm_ch_mult'],
            in_channels=opt['segm_in_channels'],
            resolution=opt['segm_resolution'],
            z_channels=opt['segm_z_channels'],
            double_z=opt['segm_double_z'],
            dropout=opt['segm_dropout']).to(self.device)
        self.segm_quantizer = VectorQuantizer(
            opt['segm_n_embed'],
            opt['segm_embed_dim'],
            beta=0.25,
            sane_index_shape=True).to(self.device)
        self.segm_quant_conv = torch.nn.Conv2d(opt['segm_z_channels'],
                                               opt['segm_embed_dim'],
                                               1).to(self.device)
        self.load_pretrained_segm_vae()

        # define sampler
        self._denoise_fn = TransformerMultiHead(
            codebook_size=opt['codebook_size'],
            segm_codebook_size=opt['segm_codebook_size'],
            texture_codebook_size=opt['texture_codebook_size'],
            bert_n_emb=opt['bert_n_emb'],
            bert_n_layers=opt['bert_n_layers'],
            bert_n_head=opt['bert_n_head'],
            block_size=opt['block_size'],
            latent_shape=opt['latent_shape'],
            embd_pdrop=opt['embd_pdrop'],
            resid_pdrop=opt['resid_pdrop'],
            attn_pdrop=opt['attn_pdrop'],
            num_head=opt['num_head']).to(self.device)

        self.num_classes = opt['codebook_size']
        self.shape = tuple(opt['latent_shape'])
        self.num_timesteps = 1000
        self.mask_id = opt['codebook_size']
        self.loss_type = opt['loss_type']
        self.mask_schedule = opt['mask_schedule']
        self.sample_steps = opt['sample_steps']

        self.init_training_settings()

    def load_pretrained_image_vae(self):
        # load the pretrained VQGAN for images
        img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
        self.img_encoder.load_state_dict(
            img_ae_checkpoint['encoder'], strict=True)
        self.img_decoder.load_state_dict(
            img_ae_checkpoint['decoder'], strict=True)
        self.img_quantizer.load_state_dict(
            img_ae_checkpoint['quantize'], strict=True)
        self.img_quant_conv.load_state_dict(
            img_ae_checkpoint['quant_conv'], strict=True)
        self.img_post_quant_conv.load_state_dict(
            img_ae_checkpoint['post_quant_conv'], strict=True)
        self.img_encoder.eval()
        self.img_decoder.eval()
        self.img_quantizer.eval()
        self.img_quant_conv.eval()
        self.img_post_quant_conv.eval()
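    # Note: both VQGANs act purely as frozen tokenizers. Their weights are
    # loaded from pretrained checkpoints, the modules are switched to eval
    # mode, and every call into them goes through @torch.no_grad(), so no
    # gradients flow back into the autoencoders during sampler training.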
    def load_pretrained_segm_vae(self):
        # load the pretrained VQGAN for segmentation masks
        segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
        self.segm_encoder.load_state_dict(
            segm_ae_checkpoint['encoder'], strict=True)
        self.segm_quantizer.load_state_dict(
            segm_ae_checkpoint['quantize'], strict=True)
        self.segm_quant_conv.load_state_dict(
            segm_ae_checkpoint['quant_conv'], strict=True)
        self.segm_encoder.eval()
        self.segm_quantizer.eval()
        self.segm_quant_conv.eval()

    def init_training_settings(self):
        optim_params = []
        for v in self._denoise_fn.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizer
        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()
        # Running per-timestep loss statistics, required by the 'importance'
        # branch of sample_time. This initialization is an assumption; only
        # the 'uniform' method is exercised in _train_loss below.
        self.Lt_history = torch.zeros(
            self.num_timesteps + 1, device=self.device)
        self.Lt_count = torch.zeros(
            self.num_timesteps + 1, device=self.device)

    @torch.no_grad()
    def get_quantized_img(self, image, texture_mask):
        encoded_img = self.img_encoder(image)
        encoded_img = self.img_quant_conv(encoded_img)

        # img_tokens_input is the continual index for the input of the
        # transformer; img_tokens_gt_list holds the indices for the 18
        # texture-aware codebooks, respectively
        _, _, [_, img_tokens_input, img_tokens_gt_list
               ] = self.img_quantizer(encoded_img, texture_mask)

        # reshape the tokens
        b = image.size(0)
        img_tokens_input = img_tokens_input.view(b, -1)
        img_tokens_gt_return_list = [
            img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
        ]

        return img_tokens_input, img_tokens_gt_return_list

    @torch.no_grad()
    def decode(self, quant):
        quant = self.img_post_quant_conv(quant)
        dec = self.img_decoder(quant)
        return dec

    @torch.no_grad()
    def decode_image_indices(self, indices_list, texture_mask):
        quant = self.img_quantizer.get_codebook_entry(
            indices_list, texture_mask,
            (indices_list[0].size(0), self.shape[0], self.shape[1],
             self.opt['img_z_channels']))
        dec = self.decode(quant)
        return dec

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            # fall back to uniform sampling until every timestep has
            # accumulated enough recorded losses
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # overwrite decoder term with L1
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)
            pt = pt_all.gather(dim=0, index=t)
            return t, pt

        elif method == 'uniform':
            t = torch.randint(
                1, self.num_timesteps + 1, (b, ), device=device).long()
            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt

        else:
            raise ValueError(f'Unknown time sampling method: {method}')
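    # The forward (noising) process implemented below is a masked,
    # absorbing-state discrete diffusion: at timestep t, each token of x_0
    # is independently replaced by the mask token with probability t/T, i.e.
    #   q(x_t^i | x_0^i) = (t/T) * [x_t^i = mask_id]
    #                      + (1 - t/T) * [x_t^i = x_0^i],
    # so t = T yields a fully masked sequence and t = 0 recovers x_0.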
    def q_sample(self, x_0, x_0_gt_list, t):
        # sample x_t from q(x_t | x_0): each token is independently set to
        # the mask token with probability t/T
        x_t = x_0.clone()

        mask = torch.rand_like(x_t.float()) < (
            t.float().unsqueeze(-1) / self.num_timesteps)
        x_t[mask] = self.mask_id

        # apply the same mask to every gt token list; unmasked positions
        # are set to -1 so the loss ignores them
        x_0_gt_ignore_list = []
        for x_0_gt in x_0_gt_list:
            x_0_gt_ignore = x_0_gt.clone()
            x_0_gt_ignore[torch.bitwise_not(mask)] = -1
            x_0_gt_ignore_list.append(x_0_gt_ignore)

        return x_t, x_0_gt_ignore_list, mask

    def _train_loss(self, x_0, x_0_gt_list):
        b, device = x_0.size(0), x_0.device

        # choose what time steps to compute loss at
        t, pt = self.sample_time(b, device, 'uniform')

        # make x noisy and denoise
        if self.mask_schedule == 'random':
            x_t, x_0_gt_ignore_list, mask = self.q_sample(
                x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
        else:
            raise NotImplementedError

        # sample p(x_0 | x_t)
        x_0_hat_logits_list = self._denoise_fn(
            x_t, self.segm_tokens, self.texture_tokens, t=t)

        # always compute the ELBO for comparison purposes
        cross_entropy_loss = 0
        for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
                                                 x_0_gt_ignore_list):
            cross_entropy_loss += F.cross_entropy(
                x_0_hat_logits.permute(0, 2, 1),
                x_0_gt_ignore,
                ignore_index=-1,
                reduction='none').sum(1)
        vb_loss = cross_entropy_loss / t
        vb_loss = vb_loss / pt
        vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())

        if self.loss_type == 'elbo':
            loss = vb_loss
        elif self.loss_type == 'mlm':
            denom = mask.float().sum(1)
            denom[denom == 0] = 1  # prevent division-by-zero errors
            loss = cross_entropy_loss / denom
        elif self.loss_type == 'reweighted_elbo':
            weight = (1 - (t / self.num_timesteps))
            loss = weight * cross_entropy_loss
            loss = loss / (math.log(2) * x_0.shape[1:].numel())
        else:
            raise ValueError(f'Unknown loss type: {self.loss_type}')

        return loss.mean(), vb_loss.mean()
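    # Note on the loss variants above: 'elbo' is the variational bound,
    # reweighting the summed cross-entropy by 1 / (t * pt) and converting
    # to bits per dimension via log(2) * numel; 'mlm' is a BERT-style
    # objective that averages the cross-entropy over the masked positions
    # only; 'reweighted_elbo' replaces the 1/t importance weight with a
    # linear weight (1 - t/T).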
    def feed_data(self, data):
        self.image = data['image'].to(self.device)
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.input_indices, self.gt_indices_list = self.get_quantized_img(
            self.image, self.texture_mask)

        self.texture_tokens = F.interpolate(
            self.texture_mask, size=self.shape,
            mode='nearest').view(self.image.size(0), -1).long()

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)

    def optimize_parameters(self):
        self._denoise_fn.train()

        loss, vb_loss = self._train_loss(self.input_indices,
                                         self.gt_indices_list)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # log scalars rather than graph-attached tensors
        self.log_dict['loss'] = loss.item()
        self.log_dict['vb_loss'] = vb_loss.item()

        self._denoise_fn.eval()

    @torch.no_grad()
    def get_quantized_segm(self, segm):
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)

        return segm_tokens

    def sample_fn(self, temp=1.0, sample_steps=None):
        self._denoise_fn.eval()

        b, device = self.image.size(0), self.device
        # start from a fully masked latent grid
        x_t = torch.ones(
            (b, np.prod(self.shape)), device=device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_mask_flatten = self.texture_tokens.view(-1)

        # min_encodings_indices_list would be used to visualize the image
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            print(f'Sample timestep {t:4d}', end='\r')
            t = torch.full((b, ), t, device=device, dtype=torch.long)

            # where to unmask
            changes = torch.rand(
                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            # update mask with changes
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self._denoise_fn(
                x_t, self.segm_tokens, self.texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape  # [b, h*w]
            x_t = x_t.view(-1)  # [b*h*w]
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # only replace the changed indices belonging to the
                    # corresponding codebook_idx
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # x_t is the input to the transformer, so its index
                    # range must be continual: each texture codebook is
                    # offset by its size (1024 entries per codebook here)
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]

            x_t = x_t.view(ori_shape)  # [b, h*w]

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self._denoise_fn.train()

        return min_encodings_indices_return_list
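    # get_vis decodes both the ground-truth token maps and the sampled
    # ones through the frozen image VQGAN and saves them side by side
    # with the input image for visual comparison.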
    def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
                save_path):
        # original image
        ori_img = self.decode_image_indices(gt_indices, texture_mask)
        # predicted image
        pred_img = self.decode_image_indices(predicted_indices, texture_mask)
        img_cat = torch.cat([
            image,
            ori_img,
            pred_img,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
        self._denoise_fn.eval()

        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            b = self.image.size(0)
            with torch.no_grad():
                sampled_indices_list = self.sample_fn(
                    temp=1, sample_steps=self.sample_steps)
            for idx in range(b):
                self.get_vis(self.image[idx:idx + 1], [
                    gt_indices[idx:idx + 1]
                    for gt_indices in self.gt_indices_list
                ], [
                    sampled_indices[idx:idx + 1]
                    for sampled_indices in sampled_indices_list
                ], self.texture_mask[idx:idx + 1],
                             f'{save_dir}/{img_name[idx]}')

        self._denoise_fn.train()

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch, iters=None):
        """Update learning rate.

        Args:
            epoch (int): Current epoch.
            iters (int): Current iteration. Only used by the 'warm_up'
                decay mode. Default: None.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # decay the learning rate to 95% at the turning point
                # (1 / 0.95 = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'warm_up':
            if iters <= self.opt['warmup_iters']:
                lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
            else:
                lr = self.opt['lr']
        else:
            raise ValueError(
                'Unknown lr mode {}'.format(self.opt['lr_decay']))

        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, net, save_path):
        """Save networks.

        Args:
            net (nn.Module): Network to be saved.
            save_path (str): Path to the saved checkpoint file.
        """
        state_dict = net.state_dict()
        torch.save(state_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self._denoise_fn.load_state_dict(checkpoint, strict=True)
        self._denoise_fn.eval()
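# Minimal usage sketch (not part of the original file): assumes a YAML
# config supplying every opt key referenced above, plus a dataloader
# yielding dicts with 'image', 'segm', 'texture_mask' and 'img_name'.
# The config path, loader, and save directory are hypothetical.
if __name__ == '__main__':
    import yaml

    with open('configs/sampler.yml') as f:  # hypothetical config path
        opt = yaml.safe_load(f)
    model = TransformerTextureAwareModel(opt)
    model.load_network()  # restore pretrained sampler weights
    # val_loader = ...  # a torch DataLoader over the validation split
    # model.inference(val_loader, 'results/samples')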