import torch.nn as nn import torch.nn.functional as F class BCELoss(nn.Module): def forward(self, prediction, target): loss = F.binary_cross_entropy_with_logits(prediction, target) return loss, {} class BCELossWithQuant(nn.Module): def __init__(self, codebook_weight=1.): super().__init__() self.codebook_weight = codebook_weight def forward(self, qloss, target, prediction, split): bce_loss = F.binary_cross_entropy_with_logits(prediction, target) loss = bce_loss + self.codebook_weight * qloss return loss, { "{}/total_loss".format(split): loss.clone().detach().mean(), "{}/bce_loss".format(split): bce_loss.detach().mean(), "{}/quant_loss".format(split): qloss.detach().mean() }