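"""Hierarchical texture-aware VQGAN inference model.

Predicts bottom-level (fine) codebook indices from top-level (coarse)
features using a UNet guidance encoder and a multi-head FCN index decoder,
then decodes the predicted indices back to an image with the pretrained
hierarchical VQGAN.
"""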
import logging
import math
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.fcn_arch import MultiHeadFCNHead
from models.archs.unet_arch import UNet
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
                                     VectorQuantizerSpatialTextureAware,
                                     VectorQuantizerTexture)
from models.losses.accuracy import accuracy
from models.losses.cross_entropy_loss import CrossEntropyLoss

logger = logging.getLogger('base')


class VQGANTextureAwareSpatialHierarchyInferenceModel():

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        # top-level (coarse) VQGAN: encoder, shared decoder and
        # texture-aware quantizer
        self.top_encoder = Encoder(
            ch=opt['top_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            double_z=opt['top_double_z'],
            dropout=opt['top_dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            ch=opt['top_ch'],
            out_ch=opt['top_out_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            dropout=opt['top_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.top_quantize = VectorQuantizerTexture(
            1024, opt['embed_dim'], beta=0.25).to(self.device)
        self.top_quant_conv = torch.nn.Conv2d(opt['top_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['top_z_channels'],
                                                   1).to(self.device)
        self.load_top_pretrain_models()

        # bottom-level (fine) VQGAN: encoder, residual decoder and
        # spatial texture-aware quantizer
        self.bot_encoder = Encoder(
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            attn_resolutions=opt['bot_attn_resolutions'],
            ch_mult=opt['bot_ch_mult'],
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            double_z=opt['bot_double_z'],
            dropout=opt['bot_dropout']).to(self.device)
        self.bot_decoder_res = DecoderRes(
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            ch_mult=opt['bot_ch_mult'],
            dropout=opt['bot_dropout'],
            give_pre_end=False).to(self.device)
        self.bot_quantize = VectorQuantizerSpatialTextureAware(
            opt['bot_n_embed'],
            opt['embed_dim'],
            beta=0.25,
            spatial_size=opt['codebook_spatial_size']).to(self.device)
        self.bot_quant_conv = torch.nn.Conv2d(opt['bot_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['bot_z_channels'],
                                                   1).to(self.device)
        self.load_bot_pretrain_network()

        # trainable networks: UNet guidance encoder and multi-head
        # index decoder
        self.guidance_encoder = UNet(
            in_channels=opt['encoder_in_channels']).to(self.device)
        self.index_decoder = MultiHeadFCNHead(
            in_channels=opt['fc_in_channels'],
            in_index=opt['fc_in_index'],
            channels=opt['fc_channels'],
            num_convs=opt['fc_num_convs'],
            concat_input=opt['fc_concat_input'],
            dropout_ratio=opt['fc_dropout_ratio'],
            num_classes=opt['fc_num_classes'],
            align_corners=opt['fc_align_corners'],
            num_head=18).to(self.device)

        self.init_training_settings()

    def init_training_settings(self):
        # collect trainable parameters from the guidance encoder and
        # the index decoder
        optim_params = []
        for v in self.guidance_encoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        for v in self.index_decoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizers
        if self.opt['optimizer'] == 'Adam':
            self.optimizer = torch.optim.Adam(
                optim_params,
                self.opt['lr'],
                weight_decay=self.opt['weight_decay'])
        elif self.opt['optimizer'] == 'SGD':
            self.optimizer = torch.optim.SGD(
                optim_params,
                self.opt['lr'],
                momentum=self.opt['momentum'],
                weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()
        if self.opt['loss_function'] == 'cross_entropy':
            self.loss_func = CrossEntropyLoss().to(self.device)

    def load_top_pretrain_models(self):
        # load the pretrained top-level (coarse) VQGAN and freeze it
        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
        self.top_encoder.load_state_dict(
            top_vae_checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(
            top_vae_checkpoint['decoder'], strict=True)
        self.top_quantize.load_state_dict(
            top_vae_checkpoint['quantize'], strict=True)
        self.top_quant_conv.load_state_dict(
            top_vae_checkpoint['quant_conv'], strict=True)
        self.top_post_quant_conv.load_state_dict(
            top_vae_checkpoint['post_quant_conv'], strict=True)
        self.top_encoder.eval()
        self.top_quantize.eval()
        self.top_quant_conv.eval()
        self.top_post_quant_conv.eval()

    def load_bot_pretrain_network(self):
        # load the pretrained bottom-level (fine) VQGAN and freeze it
        checkpoint = torch.load(self.opt['bot_vae_path'])
        self.bot_encoder.load_state_dict(
            checkpoint['bot_encoder'], strict=True)
        self.bot_decoder_res.load_state_dict(
            checkpoint['bot_decoder_res'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.bot_quantize.load_state_dict(
            checkpoint['bot_quantize'], strict=True)
        self.bot_quant_conv.load_state_dict(
            checkpoint['bot_quant_conv'], strict=True)
        self.bot_post_quant_conv.load_state_dict(
            checkpoint['bot_post_quant_conv'], strict=True)
        self.bot_encoder.eval()
        self.bot_decoder_res.eval()
        self.decoder.eval()
        self.bot_quantize.eval()
        self.bot_quant_conv.eval()
        self.bot_post_quant_conv.eval()

    def top_encode(self, x, mask):
        h = self.top_encoder(x)
        h = self.top_quant_conv(h)
        quant, _, _ = self.top_quantize(h, mask)
        quant = self.top_post_quant_conv(quant)
        # the quantized map serves both as the top-level latent and as the
        # feature fed to the guidance encoder
        return quant, quant

    def feed_data(self, data):
        self.image = data['image'].to(self.device)
        self.texture_mask = data['texture_mask'].float().to(self.device)
        self.get_gt_indices()

        self.texture_tokens = F.interpolate(
            self.texture_mask, size=(32, 16),
            mode='nearest').view(self.image.size(0), -1).long()

    def bot_encode(self, x, mask):
        h = self.bot_encoder(x)
        h = self.bot_quant_conv(h)
        _, _, (_, _, indices_list) = self.bot_quantize(h, mask)
        return indices_list

    def get_gt_indices(self):
        self.quant_t, self.feature_t = self.top_encode(self.image,
                                                       self.texture_mask)
        self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)

    def index_to_image(self, index_bottom_list, texture_mask):
        quant_b = self.bot_quantize.get_codebook_entry(
            index_bottom_list, texture_mask,
            (index_bottom_list[0].size(0), index_bottom_list[0].size(1),
             index_bottom_list[0].size(2),
             self.opt['bot_z_channels']))  # .permute(0, 3, 1, 2)
        quant_b = self.bot_post_quant_conv(quant_b)
        bot_dec_res = self.bot_decoder_res(quant_b)
        dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
        return dec

    def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
        rec_img = self.index_to_image(rec_img_index, texture_mask)
        pred_img = self.index_to_image(pred_img_index, texture_mask)
        base_img = self.decoder(self.quant_t)
        img_cat = torch.cat([
            self.image,
            rec_img,
            base_img,
            pred_img,
        ], dim=3).detach()
        # map from [-1, 1] to [0, 1] before saving
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def optimize_parameters(self):
        self.guidance_encoder.train()
        self.index_decoder.train()

        self.feature_enc = self.guidance_encoder(self.feature_t)
        self.memory_logits_list = self.index_decoder(self.feature_enc)

        # sum the cross-entropy loss over the 18 texture-specific heads
        loss = 0
        for i in range(18):
            loss += self.loss_func(
                self.memory_logits_list[i],
                self.gt_indices_list[i],
                ignore_index=-1)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # log a detached scalar so the computation graph is not retained
        self.log_dict['loss_total'] = loss.item()

    def inference(self, data_loader, save_dir):
        self.guidance_encoder.eval()
        self.index_decoder.eval()

        acc = 0
        num = 0
        for _, data in enumerate(data_loader):
            self.feed_data(data)
            img_name = data['img_name']
            num += self.image.size(0)

            texture_mask_flatten = self.texture_tokens.view(-1)
            # one index map per texture codebook; -1 marks positions outside
            # that texture region
            min_encodings_indices_list = [
                torch.full(
                    texture_mask_flatten.size(),
                    fill_value=-1,
                    dtype=torch.long,
                    device=texture_mask_flatten.device) for _ in range(18)
            ]
            with torch.no_grad():
                self.feature_enc = self.guidance_encoder(self.feature_t)
                memory_logits_list = self.index_decoder(self.feature_enc)

                batch_acc = 0
                for codebook_idx, memory_logits in enumerate(
                        memory_logits_list):
                    region_of_interest = texture_mask_flatten == codebook_idx
                    if torch.sum(region_of_interest) > 0:
                        memory_indices_pred = memory_logits.argmax(
                            dim=1).view(-1)
                        batch_acc += torch.sum(
                            memory_indices_pred[region_of_interest] ==
                            self.gt_indices_list[codebook_idx].view(
                                -1)[region_of_interest])
                        min_encodings_indices_list[codebook_idx][
                            region_of_interest] = memory_indices_pred[
                                region_of_interest]
                min_encodings_indices_return_list = [
                    min_encodings_indices.view(self.gt_indices_list[0].size())
                    for min_encodings_indices in min_encodings_indices_list
                ]
                # normalize by the number of token positions per index map
                # (all index maps share the same size)
                batch_acc = batch_acc / self.gt_indices_list[0].numel(
                ) * self.image.size(0)
                acc += batch_acc

                self.get_vis(min_encodings_indices_return_list,
                             self.gt_indices_list, self.texture_mask,
                             f'{save_dir}/{img_name[0]}')

        self.guidance_encoder.train()
        self.index_decoder.train()
        return (acc / num).item()

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.guidance_encoder.load_state_dict(
            checkpoint['guidance_encoder'], strict=True)
        self.guidance_encoder.eval()
        self.index_decoder.load_state_dict(
            checkpoint['index_decoder'], strict=True)
        self.index_decoder.eval()

    def save_network(self, save_path):
        """Save the trainable networks.

        Args:
            save_path (str): Path to save the checkpoint to.
        """
        save_dict = {}
        save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
        save_dict['index_decoder'] = self.index_decoder.state_dict()
        torch.save(save_dict, save_path)

    def update_learning_rate(self, epoch):
        """Update the learning rate according to the configured decay mode.

        Args:
            epoch (int): Current epoch.

        Returns:
            float: The updated learning rate.
        """
        lr = self.optimizer.param_groups[0]['lr']
        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # linear decay reaching 95% of the initial lr at the
                # turning point (1 / 95% = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def get_current_log(self):
        return self.log_dict
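

if __name__ == '__main__':
    # Minimal usage sketch, not part of the training pipeline. The config
    # path below is an assumed placeholder; the loaded dict must provide
    # every option key accessed in __init__ (top_*/bot_*/fc_* keys, paths
    # to the pretrained VQGAN checkpoints, optimizer settings, etc.).
    import yaml

    with open('configs/hierarchy_inference.yml', 'r') as f:
        opt = yaml.safe_load(f)

    model = VQGANTextureAwareSpatialHierarchyInferenceModel(opt)
    model.load_network()
    # `data_loader` should yield dicts with 'image', 'texture_mask' and
    # 'img_name' entries, as consumed by feed_data() / inference():
    # accuracy = model.inference(data_loader, save_dir='results')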