import logging
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.distributions as dists
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.transformer_arch import TransformerMultiHead
from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
                                     VectorQuantizerTexture)

logger = logging.getLogger('base')

class TransformerTextureAwareModel():
    """Texture-aware diffusion-based transformer model."""
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        # VQVAE for image
        self.img_encoder = Encoder(
            ch=opt['img_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            double_z=opt['img_double_z'],
            dropout=opt['img_dropout']).to(self.device)
        self.img_decoder = Decoder(
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            ch=opt['img_ch'],
            out_ch=opt['img_out_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            dropout=opt['img_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.img_quantizer = VectorQuantizerTexture(
            opt['img_n_embed'], opt['img_embed_dim'],
            beta=0.25).to(self.device)
        self.img_quant_conv = torch.nn.Conv2d(opt['img_z_channels'],
                                              opt['img_embed_dim'],
                                              1).to(self.device)
        self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
                                                   opt['img_z_channels'],
                                                   1).to(self.device)
        self.load_pretrained_image_vae()
        # VQVAE for segmentation mask
        self.segm_encoder = Encoder(
            ch=opt['segm_ch'],
            num_res_blocks=opt['segm_num_res_blocks'],
            attn_resolutions=opt['segm_attn_resolutions'],
            ch_mult=opt['segm_ch_mult'],
            in_channels=opt['segm_in_channels'],
            resolution=opt['segm_resolution'],
            z_channels=opt['segm_z_channels'],
            double_z=opt['segm_double_z'],
            dropout=opt['segm_dropout']).to(self.device)
        self.segm_quantizer = VectorQuantizer(
            opt['segm_n_embed'],
            opt['segm_embed_dim'],
            beta=0.25,
            sane_index_shape=True).to(self.device)
        self.segm_quant_conv = torch.nn.Conv2d(opt['segm_z_channels'],
                                               opt['segm_embed_dim'],
                                               1).to(self.device)
        self.load_pretrained_segm_vae()

        # define sampler
        self._denoise_fn = TransformerMultiHead(
            codebook_size=opt['codebook_size'],
            segm_codebook_size=opt['segm_codebook_size'],
            texture_codebook_size=opt['texture_codebook_size'],
            bert_n_emb=opt['bert_n_emb'],
            bert_n_layers=opt['bert_n_layers'],
            bert_n_head=opt['bert_n_head'],
            block_size=opt['block_size'],
            latent_shape=opt['latent_shape'],
            embd_pdrop=opt['embd_pdrop'],
            resid_pdrop=opt['resid_pdrop'],
            attn_pdrop=opt['attn_pdrop'],
            num_head=opt['num_head']).to(self.device)

        self.num_classes = opt['codebook_size']
        self.shape = tuple(opt['latent_shape'])
        self.num_timesteps = 1000
        self.mask_id = opt['codebook_size']
        self.loss_type = opt['loss_type']
        self.mask_schedule = opt['mask_schedule']
        self.sample_steps = opt['sample_steps']

        self.init_training_settings()

    def load_pretrained_image_vae(self):
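        """Load the pretrained VQGAN image autoencoder (encoder, decoder,
        texture-aware quantizer and the two 1x1 convs) from
        opt['img_ae_path'] and switch all parts to eval mode; these modules
        are not optimized during training."""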
        # load pretrained vqgan for images
        img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
        self.img_encoder.load_state_dict(
            img_ae_checkpoint['encoder'], strict=True)
        self.img_decoder.load_state_dict(
            img_ae_checkpoint['decoder'], strict=True)
        self.img_quantizer.load_state_dict(
            img_ae_checkpoint['quantize'], strict=True)
        self.img_quant_conv.load_state_dict(
            img_ae_checkpoint['quant_conv'], strict=True)
        self.img_post_quant_conv.load_state_dict(
            img_ae_checkpoint['post_quant_conv'], strict=True)
        self.img_encoder.eval()
        self.img_decoder.eval()
        self.img_quantizer.eval()
        self.img_quant_conv.eval()
        self.img_post_quant_conv.eval()

    def load_pretrained_segm_vae(self):
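        """Load the pretrained VQGAN segmentation-mask autoencoder (encoder,
        quantizer and quant conv) from opt['segm_ae_path'] and switch it to
        eval mode."""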
        # load pretrained vqgan for segmentation mask
        segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
        self.segm_encoder.load_state_dict(
            segm_ae_checkpoint['encoder'], strict=True)
        self.segm_quantizer.load_state_dict(
            segm_ae_checkpoint['quantize'], strict=True)
        self.segm_quant_conv.load_state_dict(
            segm_ae_checkpoint['quant_conv'], strict=True)
        self.segm_encoder.eval()
        self.segm_quantizer.eval()
        self.segm_quant_conv.eval()

    def init_training_settings(self):
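        """Collect the trainable parameters of the transformer, set up the
        Adam optimizer and the logging dict."""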
        optim_params = []
        for v in self._denoise_fn.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizer
        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()

    def get_quantized_img(self, image, texture_mask):
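        """Encode an image into discrete token indices with the texture-aware
        quantizer.

        Returns:
            img_tokens_input: flattened indices in the combined index range,
                used as the transformer input.
            img_tokens_gt_return_list: one flattened ground-truth index map
                per texture-aware codebook (18 in total).
        """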
        encoded_img = self.img_encoder(image)
        encoded_img = self.img_quant_conv(encoded_img)

        # img_tokens_input: indices in the combined (continuous) index range,
        # used as the input to the transformer
        # img_tokens_gt_list: indices for each of the 18 texture-aware
        # codebooks, used as the ground truth for the corresponding head
        _, _, [_, img_tokens_input, img_tokens_gt_list
               ] = self.img_quantizer(encoded_img, texture_mask)

        # reshape the tokens
        b = image.size(0)
        img_tokens_input = img_tokens_input.view(b, -1)
        img_tokens_gt_return_list = [
            img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
        ]
        return img_tokens_input, img_tokens_gt_return_list

    def decode(self, quant):
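        """Map quantized features back to image space through the post-quant
        conv and the image decoder."""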
        quant = self.img_post_quant_conv(quant)
        dec = self.img_decoder(quant)
        return dec

    def decode_image_indices(self, indices_list, texture_mask):
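        """Look up the codebook entries for the per-codebook index maps and
        decode them into an image."""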
        quant = self.img_quantizer.get_codebook_entry(
            indices_list, texture_mask,
            (indices_list[0].size(0), self.shape[0], self.shape[1],
             self.opt['img_z_channels']))
        dec = self.decode(quant)
        return dec

    def sample_time(self, b, device, method='uniform'):
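        """Sample a batch of diffusion timesteps together with their sampling
        probabilities pt. 'uniform' draws t uniformly from [1, T];
        'importance' reweights by the loss history (requires the
        Lt_history / Lt_count buffers)."""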
        if method == 'importance':
            # note: importance sampling relies on running loss statistics
            # (self.Lt_count / self.Lt_history) that this model never
            # initializes; only 'uniform' is used in _train_loss
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')
            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()
            t = torch.multinomial(pt_all, num_samples=b, replacement=True)
            pt = pt_all.gather(dim=0, index=t)
            return t, pt
        elif method == 'uniform':
            t = torch.randint(
                1, self.num_timesteps + 1, (b, ), device=device).long()
            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt
        else:
            raise ValueError(f'Unknown time sampling method: {method}')

    def q_sample(self, x_0, x_0_gt_list, t):
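        """Draw x_t ~ q(x_t | x_0) by masking each token of x_0 independently
        with probability t/T, and build the matching ground-truth lists in
        which unmasked positions are set to -1 so the cross-entropy loss
        ignores them."""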
        # sample from q(x_t | x_0): randomly replace tokens with the mask
        # token with probability t/T
        x_t = x_0.clone()
        mask = torch.rand_like(x_t.float()) < (
            t.float().unsqueeze(-1) / self.num_timesteps)
        x_t[mask] = self.mask_id

        # apply the same mask to every ground-truth token list; positions
        # that were not masked are set to -1 so they are ignored in the loss
        x_0_gt_ignore_list = []
        for x_0_gt in x_0_gt_list:
            x_0_gt_ignore = x_0_gt.clone()
            x_0_gt_ignore[torch.bitwise_not(mask)] = -1
            x_0_gt_ignore_list.append(x_0_gt_ignore)
        return x_t, x_0_gt_ignore_list, mask

    def _train_loss(self, x_0, x_0_gt_list):
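        """Compute the masked-token training loss.

        A timestep is sampled per batch element, the input tokens are masked
        with q_sample, the transformer predicts logits for every texture-aware
        codebook, and the summed per-codebook cross-entropies are turned into
        the ELBO / MLM / reweighted-ELBO objective selected by self.loss_type.
        """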
        b, device = x_0.size(0), x_0.device

        # choose what time steps to compute loss at
        t, pt = self.sample_time(b, device, 'uniform')

        # make x noisy and denoise
        if self.mask_schedule == 'random':
            x_t, x_0_gt_ignore_list, mask = self.q_sample(
                x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
        else:
            raise NotImplementedError

        # sample p(x_0 | x_t)
        x_0_hat_logits_list = self._denoise_fn(
            x_t, self.segm_tokens, self.texture_tokens, t=t)

        # Always compute ELBO for comparison purposes
        cross_entropy_loss = 0
        for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
                                                 x_0_gt_ignore_list):
            cross_entropy_loss += F.cross_entropy(
                x_0_hat_logits.permute(0, 2, 1),
                x_0_gt_ignore,
                ignore_index=-1,
                reduction='none').sum(1)
        vb_loss = cross_entropy_loss / t
        vb_loss = vb_loss / pt
        vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())

        if self.loss_type == 'elbo':
            loss = vb_loss
        elif self.loss_type == 'mlm':
            denom = mask.float().sum(1)
            denom[denom == 0] = 1  # prevent divide-by-zero errors
            loss = cross_entropy_loss / denom
        elif self.loss_type == 'reweighted_elbo':
            weight = (1 - (t / self.num_timesteps))
            loss = weight * cross_entropy_loss
            loss = loss / (math.log(2) * x_0.shape[1:].numel())
        else:
            raise ValueError(f'Unknown loss type: {self.loss_type}')

        return loss.mean(), vb_loss.mean()

    def feed_data(self, data):
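        """Move a training batch to the GPU and precompute the image tokens,
        the texture tokens (texture mask downsampled to the latent shape) and
        the segmentation tokens used by the transformer."""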
        self.image = data['image'].to(self.device)
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.input_indices, self.gt_indices_list = self.get_quantized_img(
            self.image, self.texture_mask)

        self.texture_tokens = F.interpolate(
            self.texture_mask, size=self.shape,
            mode='nearest').view(self.image.size(0), -1).long()

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)

    def optimize_parameters(self):
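        """Run one optimization step: compute the training loss, backpropagate
        and update the transformer weights, then record the losses."""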
        self._denoise_fn.train()

        loss, vb_loss = self._train_loss(self.input_indices,
                                         self.gt_indices_list)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_dict['loss'] = loss
        self.log_dict['vb_loss'] = vb_loss

        self._denoise_fn.eval()

    def get_quantized_segm(self, segm):
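        """One-hot encode the segmentation map and quantize it with the
        pretrained segmentation VQ-VAE, returning the token indices."""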
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
        return segm_tokens

    def sample_fn(self, temp=1.0, sample_steps=None):
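        """Iteratively unmask tokens to sample image indices conditioned on
        the segmentation and texture tokens.

        Starting from a fully masked latent, at each timestep a subset of
        still-masked positions is unmasked and filled with samples from the
        predicted categorical distribution of the codebook that the texture
        mask selects at that position. Returns one index map per texture-aware
        codebook, with -1 wherever that codebook is not used."""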
        self._denoise_fn.eval()

        b, device = self.image.size(0), self.device
        x_t = torch.ones(
            (b, np.prod(self.shape)), device=device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_mask_flatten = self.texture_tokens.view(-1)

        # min_encodings_indices_list is used to visualize the sampled image:
        # one index map per texture-aware codebook, -1 where that codebook is
        # not used
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            print(f'Sample timestep {t:4d}', end='\r')
            t = torch.full((b, ), t, device=device, dtype=torch.long)

            # where to unmask
            changes = torch.rand(
                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            # update mask with changes
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self._denoise_fn(
                x_t, self.segm_tokens, self.texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape  # [b, h*w]
            x_t = x_t.view(-1)  # [b*h*w]
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # only fill the newly unmasked positions that belong to
                    # this codebook
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # x_t is fed back to the transformer, so the sampled index
                    # is shifted into the combined index range (offset of 1024
                    # per texture-aware codebook)
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]
            x_t = x_t.view(ori_shape)  # [b, h*w]

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self._denoise_fn.train()

        return min_encodings_indices_return_list

    def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
                save_path):
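        """Decode the ground-truth and predicted index maps and save the input
        image, its reconstruction and the prediction side by side."""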
        # reconstruction from the ground-truth indices
        ori_img = self.decode_image_indices(gt_indices, texture_mask)
        # image decoded from the predicted indices
        pred_img = self.decode_image_indices(predicted_indices, texture_mask)
        img_cat = torch.cat([
            image,
            ori_img,
            pred_img,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
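        """Sample predictions for every batch in the data loader and write the
        visualization of each image to save_dir."""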
        self._denoise_fn.eval()

        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            b = self.image.size(0)
            with torch.no_grad():
                sampled_indices_list = self.sample_fn(
                    temp=1, sample_steps=self.sample_steps)
            for idx in range(b):
                self.get_vis(self.image[idx:idx + 1], [
                    gt_indices[idx:idx + 1]
                    for gt_indices in self.gt_indices_list
                ], [
                    sampled_indices[idx:idx + 1]
                    for sampled_indices in sampled_indices_list
                ], self.texture_mask[idx:idx + 1],
                             f'{save_dir}/{img_name[idx]}')

        self._denoise_fn.train()
    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch, iters=None):
        """Update the learning rate according to opt['lr_decay'].

        Args:
            epoch (int): Current epoch.
            iters (int): Current iteration, only used by the 'warm_up'
                schedule. Default: None.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # linear decay so that only 5% of the initial learning rate
                # remains at the turning point (1 / 0.95 = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'warm_up':
            if iters <= self.opt['warmup_iters']:
                lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
            else:
                lr = self.opt['lr']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))

        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr
    def save_network(self, net, save_path):
        """Save a network's state dict.

        Args:
            net (nn.Module): Network to be saved.
            save_path (str): Path of the saved file.
        """
        state_dict = net.state_dict()
        torch.save(state_dict, save_path)

    def load_network(self):
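        """Load a pretrained transformer sampler from
        opt['pretrained_sampler'] and put it in eval mode."""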
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self._denoise_fn.load_state_dict(checkpoint, strict=True)
        self._denoise_fn.eval()