Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch.nn.functional as F | |
class BCELoss(nn.Module):
    """Binary cross-entropy criterion that operates on raw logits.

    The criterion is stateless. Calling it returns a ``(loss, log_dict)`` pair,
    where the log dict is always empty because no auxiliary metrics are
    reported.
    """

    def forward(self, prediction, target):
        """Compute BCE-with-logits between ``prediction`` and ``target``.

        Args:
            prediction: unnormalised logits (no sigmoid applied beforehand).
            target: tensor of target probabilities; presumably the same shape
                as ``prediction`` — TODO confirm with callers.

        Returns:
            A tuple of ``(loss, {})``. ``loss`` uses PyTorch's default
            ``reduction='mean'``.
        """
        extra_logs = {}
        bce = F.binary_cross_entropy_with_logits(prediction, target)
        return bce, extra_logs
class BCELossWithQuant(nn.Module):
    """BCE-with-logits loss combined with a weighted quantization loss term.

    The total loss is ``bce + codebook_weight * qloss``. Detached copies of the
    total, the BCE term and the quantization term are returned for logging.
    """

    def __init__(self, codebook_weight=1.):
        """Store the multiplier applied to ``qloss`` in :meth:`forward`.

        Args:
            codebook_weight: scale factor for the quantization loss term.
        """
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        """Return the combined loss and a dict of detached logging values.

        Args:
            qloss: quantization/codebook loss produced upstream; presumably a
                scalar tensor — TODO confirm. It is reduced with ``.mean()``
                for logging.
            target: target probabilities for the BCE term.
            prediction: unnormalised logits for the BCE term.
            split: string prefix for the log keys (e.g. ``"train"``).

        Returns:
            A tuple of ``(loss, log)``:

            * ``loss`` is ``bce_loss + codebook_weight * qloss``, where
              ``bce_loss`` uses PyTorch's default ``reduction='mean'``.
            * ``log`` maps ``"{split}/total_loss"``, ``"{split}/bce_loss"`` and
              ``"{split}/quant_loss"`` to detached, mean-reduced tensors.
        """
        bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = bce_loss + self.codebook_weight * qloss
        # detach() cuts the graph and mean() allocates a new tensor, so no
        # extra clone() is needed to keep logged values independent of `loss`.
        return loss, {
            f"{split}/total_loss": loss.detach().mean(),
            f"{split}/bce_loss": bce_loss.detach().mean(),
            f"{split}/quant_loss": qloss.detach().mean(),
        }