import torch import torch.nn as nn import torch.nn.functional as F from scipy.optimize import linear_sum_assignment from .tokenizer import PAD_ID, MASK, MASK_ID class LabelSmoothingLoss(nn.Module): """ With label smoothing, KL-divergence between q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. """ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): assert 0.0 < label_smoothing <= 1.0 self.ignore_index = ignore_index super(LabelSmoothingLoss, self).__init__() smoothing_value = label_smoothing / (tgt_vocab_size - 2) one_hot = torch.full((tgt_vocab_size,), smoothing_value) one_hot[self.ignore_index] = 0 self.register_buffer('one_hot', one_hot.unsqueeze(0)) self.confidence = 1.0 - label_smoothing def forward(self, output, target): """ output (FloatTensor): batch_size x n_classes target (LongTensor): batch_size """ # assuming output is raw logits # convert to log_probs log_probs = F.log_softmax(output, dim=-1) model_prob = self.one_hot.repeat(target.size(0), 1) model_prob.scatter_(1, target.unsqueeze(1), self.confidence) model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) # reduction mean or sum? return F.kl_div(log_probs, model_prob, reduction='batchmean') class SequenceLoss(nn.Module): def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=[]): super(SequenceLoss, self).__init__() if ignore_indices: ignore_index = ignore_indices[0] self.ignore_index = ignore_index self.ignore_indices = ignore_indices if label_smoothing == 0: self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean') else: self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index) def forward(self, output, target): """ :param output: [batch, len, vocab] :param target: [batch, len] :return: """ batch_size, max_len, vocab_size = output.size() output = output.reshape(-1, vocab_size) target = target.reshape(-1) for idx in self.ignore_indices: if idx != self.ignore_index: target.masked_fill_((target == idx), self.ignore_index) loss = self.criterion(output, target) return loss class GraphLoss(nn.Module): def __init__(self): super(GraphLoss, self).__init__() weight = torch.ones(7) * 10 weight[0] = 1 self.criterion = nn.CrossEntropyLoss(weight, ignore_index=-100) def forward(self, outputs, targets): results = {} if 'coords' in outputs: pred = outputs['coords'] max_len = pred.size(1) target = targets['coords'][:, :max_len] mask = target.ge(0) loss = F.l1_loss(pred, target, reduction='none') results['coords'] = (loss * mask).sum() / mask.sum() if 'edges' in outputs: pred = outputs['edges'] max_len = pred.size(-1) target = targets['edges'][:, :max_len, :max_len] results['edges'] = self.criterion(pred, target) return results class Criterion(nn.Module): def __init__(self, args, tokenizer): super(Criterion, self).__init__() criterion = {} for format_ in args.formats: if format_ == 'edges': criterion['edges'] = GraphLoss() else: if MASK in tokenizer[format_].stoi: ignore_indices = [PAD_ID, MASK_ID] else: ignore_indices = [] criterion[format_] = SequenceLoss(args.label_smoothing, len(tokenizer[format_]), ignore_index=PAD_ID, ignore_indices=ignore_indices) self.criterion = nn.ModuleDict(criterion) def forward(self, results, refs): losses = {} for format_ in results: predictions, targets, *_ = results[format_] loss_ = self.criterion[format_](predictions, targets) if type(loss_) is dict: losses.update(loss_) else: if loss_.numel() > 1: loss_ = loss_.mean() losses[format_] = loss_ return losses