"""VQGAN model wrappers: VQ autoencoders for segmentation maps and images."""
import math | |
import sys | |
from collections import OrderedDict | |
sys.path.append('..') | |
import lpips | |
import torch | |
import torch.nn.functional as F | |
from torchvision.utils import save_image | |
from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder, | |
VectorQuantizer, VectorQuantizerTexture) | |
from models.losses.segmentation_loss import BCELossWithQuant | |
from models.losses.vqgan_loss import (DiffAugment, adopt_weight, | |
calculate_adaptive_weight, hinge_d_loss) | |
class VQModel():
    """Base VQ autoencoder: encoder -> vector quantizer -> decoder.

    Builds the sub-networks from the ``opt`` dict and moves them to CUDA.
    Subclasses add losses/optimizers via ``init_training_settings``.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.quantize = VectorQuantizer(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        # 1x1 convs adapt between the encoder/decoder channel width
        # (z_channels) and the codebook embedding dimension.
        self.quant_conv = torch.nn.Conv2d(opt['z_channels'], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt['z_channels'],
                                               1).to(self.device)

    def init_training_settings(self):
        """Create the loss, the log dict, and the optimizer(s)."""
        self.loss = BCELossWithQuant()
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def save_network(self, save_path):
        """Save all sub-network state dicts to a single checkpoint file.

        Args:
            save_path (str): Path of the checkpoint file to write.
        """
        save_dict = {}
        save_dict['encoder'] = self.encoder.state_dict()
        save_dict['decoder'] = self.decoder.state_dict()
        save_dict['quantize'] = self.quantize.state_dict()
        save_dict['quant_conv'] = self.quant_conv.state_dict()
        save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
        # Fix: the base model defines no discriminator; only GAN subclasses
        # (e.g. VQImageModel) set `self.disc`. Unconditionally reading it
        # raised AttributeError for non-GAN models.
        if hasattr(self, 'disc'):
            save_dict['discriminator'] = self.disc.state_dict()
        torch.save(save_dict, save_path)

    def load_network(self):
        """Load sub-network weights from ``opt['pretrained_models']``."""
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
        self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
        self.post_quant_conv.load_state_dict(
            checkpoint['post_quant_conv'], strict=True)

    def optimize_parameters(self, data, current_iter):
        """Run one optimization step on a batch.

        Args:
            data (dict): Input batch (see ``feed_data``).
            current_iter (int): Current iteration (unused here; kept for
                interface compatibility with subclasses).
        """
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss = self.training_step(data)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def encode(self, x):
        """Encode an input batch to quantized latents.

        Returns:
            tuple: (quantized latents, embedding loss, quantizer info).
        """
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        """Decode quantized latents back to the output space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode directly from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input):
        """Full autoencoding pass; returns (reconstruction, codebook loss)."""
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def feed_data(self, data):
        """Convert a segmentation batch to one-hot channels-first floats.

        Args:
            data (dict): Batch containing integer class map ``data['segm']``.
                NOTE(review): assumes shape (batch, H, W) — confirm against
                the data loader.

        Returns:
            Tensor: One-hot float tensor of shape (batch, classes, H, W)
            on the model device.
        """
        x = data['segm']
        x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])

        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float().to(self.device)

    def get_current_log(self):
        """Return the dict of the most recent logged training values."""
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Update the optimizer learning rate according to the decay policy.

        Args:
            epoch (int): Current epoch.

        Returns:
            float: The learning rate that was applied.

        Raises:
            ValueError: If ``opt['lr_decay']`` names an unknown policy.
        """
        lr = self.optimizer.param_groups[0]['lr']
        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # learning rate decay as 95%
                # at the turning point (1 / 95% = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
class VQSegmentationModel(VQModel):
    """VQ autoencoder trained on one-hot segmentation maps."""

    def __init__(self, opt):
        super().__init__(opt)
        # Fixed random projection used to visualize many-channel one-hot
        # maps as 3-channel images.
        self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
                                    1).to(self.device)
        self.init_training_settings()

    def configure_optimizers(self):
        """Build a single Adam optimizer over all autoencoder parameters."""
        trainable = (
            list(self.encoder.parameters()) + list(self.decoder.parameters())
            + list(self.quantize.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()))
        self.optimizer = torch.optim.Adam(
            trainable, lr=self.opt['lr'], betas=(0.5, 0.9))

    def training_step(self, data):
        """Forward one batch and return the autoencoder training loss."""
        segm = self.feed_data(data)
        recon, quant_loss = self.forward_step(segm)
        aeloss, ae_log = self.loss(quant_loss, segm, recon, split="train")
        self.log_dict.update(ae_log)
        return aeloss

    def to_rgb(self, x):
        """Project a one-hot map to 3 channels and rescale to [-1, 1]."""
        projected = F.conv2d(x, weight=self.colorize)
        lo, hi = projected.min(), projected.max()
        return 2. * (projected - lo) / (hi - lo) - 1.

    def inference(self, data_loader, save_dir):
        """Reconstruct each batch, save side-by-side images, return losses.

        Returns:
            tuple: mean (total, BCE, quantization) validation losses.
        """
        for module in (self.encoder, self.decoder, self.quantize,
                       self.quant_conv, self.post_quant_conv):
            module.eval()

        loss_total = 0
        loss_bce = 0
        loss_quant = 0
        num = 0
        for data in data_loader:
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, qloss = self.forward_step(x)
            _, val_log = self.loss(qloss, x, xrec, split="val")

            loss_total += val_log['val/total_loss']
            loss_bce += val_log['val/bce_loss']
            loss_quant += val_log['val/quant_loss']
            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                # logits -> hard class indices -> one-hot, channels-first
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            # Map from [-1, 1] to [0, 1] for saving.
            img_cat = (torch.cat([x, xrec], dim=3).detach() + 1) / 2
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item(), (loss_bce / num).item(), (
            loss_quant / num).item()
class VQImageModel(VQModel):
    """VQGAN for RGB images: VQ autoencoder + discriminator + LPIPS loss."""

    def __init__(self, opt):
        super().__init__(opt)
        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        # The adversarial term is enabled only after this iteration.
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = "color,translation"

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        """Move the image batch to the model device as float."""
        x = data['image']
        return x.float().to(self.device)

    def init_training_settings(self):
        """Create the log dict and optimizers (no BCE loss for images)."""
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def configure_optimizers(self):
        """Separate Adam optimizers for the autoencoder and discriminator."""
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()) +
            list(self.quantize.parameters()) +
            list(self.quant_conv.parameters()) +
            list(self.post_quant_conv.parameters()),
            lr=self.opt['lr'])

        self.disc_optimizer = torch.optim.Adam(
            self.disc.parameters(), lr=self.opt['lr'])

    def training_step(self, data, step):
        """Compute the generator loss and, after warm-up, the disc loss.

        Args:
            data (dict): Batch containing 'image'.
            step (int): Global iteration; gates the adversarial term.

        Returns:
            tuple: (generator loss tensor, discriminator loss tensor or
            None before ``disc_start_step``).
        """
        x = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x)

        # get recon/perceptual loss
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # augment for input to discriminator
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # update generator
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        # Fix: log plain Python numbers instead of graph-attached tensors
        # ("loss", "d_weight", "d_loss" previously stored tensors), so the
        # log dict does not retain the autograd graph between iterations.
        self.log_dict["loss"] = loss.item()
        self.log_dict["l1"] = recon_loss.mean().item()
        self.log_dict["perceptual"] = p_loss.mean().item()
        self.log_dict["nll_loss"] = nll_loss.item()
        self.log_dict["g_loss"] = g_loss.item()
        self.log_dict["d_weight"] = float(d_weight)
        self.log_dict["codebook_loss"] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            # detach so that the generator isn't also updated
            logits_fake = self.disc(xrec.contiguous().detach())
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict["d_loss"] = d_loss.item()
        else:
            d_loss = None

        return loss, d_loss

    def optimize_parameters(self, data, step):
        """Run one generator step and, after warm-up, one disc step."""
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss, d_loss = self.training_step(data, step)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if step > self.disc_start_step:
            self.disc_optimizer.zero_grad()
            d_loss.backward()
            self.disc_optimizer.step()

    def inference(self, data_loader, save_dir):
        """Reconstruct every batch, save comparisons, return the mean loss."""
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, _ = self.forward_step(x)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                # NOTE(review): self.to_rgb is not defined on this class
                # (only on VQSegmentationModel); this branch would raise
                # AttributeError. It only triggers for >3-channel inputs —
                # confirm it is unreachable for RGB data.
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()
class VQImageSegmTextureModel(VQImageModel):
    """VQGAN for images using a texture-aware vector quantizer.

    Mirrors ``VQImageModel`` but uses ``VectorQuantizerTexture``, whose
    forward pass additionally consumes a texture mask. The sub-networks are
    built here directly instead of via ``super().__init__`` because the
    parent constructor would create the wrong quantizer and register its
    parameters with the optimizer before it could be replaced.
    """

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        # Texture-aware quantizer: forward takes (latents, texture mask).
        self.quantize = VectorQuantizerTexture(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        self.quant_conv = torch.nn.Conv2d(opt['z_channels'], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt['z_channels'],
                                               1).to(self.device)
        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        # The adversarial term is enabled only after this iteration.
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = "color,translation"

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        """Return the image batch and its texture mask on the model device."""
        x = data['image'].float().to(self.device)
        mask = data['texture_mask'].float().to(self.device)

        return x, mask

    def training_step(self, data, step):
        """Compute the generator loss and, after warm-up, the disc loss.

        Args:
            data (dict): Batch containing 'image' and 'texture_mask'.
            step (int): Global iteration; gates the adversarial term.

        Returns:
            tuple: (generator loss tensor, discriminator loss tensor or
            None before ``disc_start_step``).
        """
        x, mask = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x, mask)

        # get recon/perceptual loss
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # augment for input to discriminator
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # update generator
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        # Fix: log plain Python numbers instead of graph-attached tensors
        # ("loss", "d_weight", "d_loss" previously stored tensors), so the
        # log dict does not retain the autograd graph between iterations.
        self.log_dict["loss"] = loss.item()
        self.log_dict["l1"] = recon_loss.mean().item()
        self.log_dict["perceptual"] = p_loss.mean().item()
        self.log_dict["nll_loss"] = nll_loss.item()
        self.log_dict["g_loss"] = g_loss.item()
        self.log_dict["d_weight"] = float(d_weight)
        self.log_dict["codebook_loss"] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            # detach so that the generator isn't also updated
            logits_fake = self.disc(xrec.contiguous().detach())
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict["d_loss"] = d_loss.item()
        else:
            d_loss = None

        return loss, d_loss

    def inference(self, data_loader, save_dir):
        """Reconstruct every batch, save comparisons, return the mean loss."""
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x, mask = self.feed_data(data)
            xrec, _ = self.forward_step(x, mask)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                # NOTE(review): self.to_rgb is not defined anywhere on this
                # class's MRO; this branch would raise AttributeError. It
                # only triggers for >3-channel inputs — confirm it is
                # unreachable for RGB data.
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()

    def encode(self, x, mask):
        """Encode an image batch; the quantizer also receives the mask."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h, mask)
        return quant, emb_loss, info

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode directly from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input, mask):
        """Full autoencoding pass; returns (reconstruction, codebook loss)."""
        quant, diff, _ = self.encode(input, mask)
        dec = self.decode(quant)
        return dec, diff