import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse


def dice_loss(probas, labels, smooth=1):
    """Soft multi-class Dice loss.

    Args:
        probas: [P, C] tensor of class probabilities (each in [0, 1]).
        labels: [P] tensor of ground-truth labels in [0, C - 1].
        smooth: additive smoothing term; avoids division by zero and softens
            the loss for very small regions.

    Returns:
        Scalar tensor: mean of (1 - Dice coefficient) over the classes that
        actually appear in `labels`. If no class appears, `mean` returns its
        `empty` default (0, a plain int) — callers should ensure at least one
        valid pixel.
    """
    num_classes = probas.size(1)
    losses = []
    for c in range(num_classes):
        fg = (labels == c).float()
        if fg.sum() == 0:
            # Class absent from the ground truth: skip it so empty classes
            # do not drag the mean loss towards zero.
            continue
        class_pred = probas[:, c]
        numerator = 2 * torch.sum(class_pred * fg) + smooth
        denominator = torch.sum(class_pred) + torch.sum(fg) + smooth
        losses.append(1 - numerator / denominator)
    return mean(losses)


def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6):
    """Tversky loss function.

    probas: [P, C] tensor, class probabilities at each prediction
        (between 0 and 1).
    labels: [P] tensor, ground truth labels (between 0 and C - 1).
    alpha: weight of false positives.
    beta: weight of false negatives.
    epsilon: small constant guarding the division.

    Same as soft Dice loss when alpha = beta = 0.5.
    Same as Jaccard loss when alpha = beta = 1.0.
    See `Tversky loss function for image segmentation using 3D fully
    convolutional deep networks`: https://arxiv.org/pdf/1706.05721.pdf
    """
    num_classes = probas.size(1)
    losses = []
    for c in range(num_classes):
        fg = (labels == c).float()
        if fg.sum() == 0:
            # Skip classes that do not occur in the ground truth.
            continue
        class_pred = probas[:, c]
        p0 = class_pred        # predicted foreground probability
        p1 = 1 - class_pred    # predicted background probability
        g0 = fg                # ground-truth foreground mask
        g1 = 1 - fg            # ground-truth background mask
        true_pos = torch.sum(p0 * g0)
        false_pos = torch.sum(p0 * g1)
        false_neg = torch.sum(p1 * g0)
        denominator = true_pos + alpha * false_pos + beta * false_neg
        losses.append(1 - true_pos / (denominator + epsilon))
    return mean(losses)


def flatten_probas(probas, labels, ignore=255):
    """Flatten predictions in the batch, dropping `ignore`-labelled pixels.

    Args:
        probas: [B, C, H, W] tensor of per-pixel class probabilities.
        labels: tensor with B*H*W elements of ground-truth labels.
        ignore: label value to drop, or None to keep every pixel.

    Returns:
        (vprobas, vlabels): [P, C] probabilities and [P] labels for the
        pixels whose label is not `ignore`.
    """
    B, C, H, W = probas.size()
    # [B, C, H, W] -> [B, H, W, C] -> [B*H*W, C] = [P, C]
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.view(-1, 1).expand(-1, C)].reshape(-1, C)
    vlabels = labels[valid]
    return vprobas, vlabels


def isnan(x):
    """Return True where x is NaN (NaN is the only value unequal to itself)."""
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """nanmean compatible with generators.

    Args:
        l: iterable of values (numbers or tensors supporting += and /).
        ignore_nan: if True, NaN entries are filtered out before averaging.
        empty: value returned for an empty iterable; pass the string
            'raise' to raise ValueError instead.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    # Running sum; n ends up at the total element count.
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n


class DiceLoss(nn.Module):
    """Multi-class soft Dice loss over a list of prediction/label pairs."""

    def __init__(self, ignore_index=255):
        super(DiceLoss, self).__init__()
        # Label value excluded from the loss (unlabelled pixels).
        self.ignore_index = ignore_index

    def forward(self, tmp_dic, label_dic, step=None):
        """Compute one Dice loss per entry.

        Args:
            tmp_dic: list of [B, C, H, W] logit tensors.
            label_dic: list of label tensors with B*H*W elements each.
            step: unused; kept for interface parity with CrossEntropyLoss.

        Returns:
            1-D tensor with one loss value per list entry.
        """
        total_loss = []
        for idx in range(len(tmp_dic)):
            pred = tmp_dic[idx]
            label = label_dic[idx]
            pred = F.softmax(pred, dim=1)
            label = label.view(1, 1, pred.size()[2], pred.size()[3])
            loss = dice_loss(
                *flatten_probas(pred, label, ignore=self.ignore_index))
            total_loss.append(loss.unsqueeze(0))
        total_loss = torch.cat(total_loss, dim=0)
        return total_loss


class SoftJaccordLoss(nn.Module):
    """Soft Jaccard (IoU) loss: Tversky loss with alpha = beta = 1.0."""

    def __init__(self, ignore_index=255):
        super(SoftJaccordLoss, self).__init__()
        # Label value excluded from the loss (unlabelled pixels).
        self.ignore_index = ignore_index

    def forward(self, tmp_dic, label_dic, step=None):
        """Compute one Jaccard loss per entry.

        Args:
            tmp_dic: list of [B, C, H, W] logit tensors.
            label_dic: list of label tensors with B*H*W elements each.
            step: unused; kept for interface parity with CrossEntropyLoss.

        Returns:
            1-D tensor with one loss value per list entry.
        """
        total_loss = []
        for idx in range(len(tmp_dic)):
            pred = tmp_dic[idx]
            label = label_dic[idx]
            pred = F.softmax(pred, dim=1)
            label = label.view(1, 1, pred.size()[2], pred.size()[3])
            loss = tversky_loss(*flatten_probas(pred, label,
                                                ignore=self.ignore_index),
                                alpha=1.0, beta=1.0)
            total_loss.append(loss.unsqueeze(0))
        total_loss = torch.cat(total_loss, dim=0)
        return total_loss


class CrossEntropyLoss(nn.Module):
    """Pixel-wise cross-entropy with optional online hard example mining.

    Args:
        top_k_percent_pixels: if set (strictly between 0 and 1), only the
            hardest `top_k_percent_pixels` fraction of pixels contributes to
            the loss; if None, the standard mean cross-entropy is used.
        hard_example_mining_step: number of steps over which the mined
            fraction anneals from 100% of pixels down to
            `top_k_percent_pixels`; 0 disables annealing and mines the fixed
            fraction from the start.
    """

    def __init__(self, top_k_percent_pixels=None,
                 hard_example_mining_step=100000):
        super(CrossEntropyLoss, self).__init__()
        self.top_k_percent_pixels = top_k_percent_pixels
        if top_k_percent_pixels is not None:
            assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1)
        # BUGFIX: the original stored `hard_example_mining_step + 1e-5`,
        # which made the `== 0` check in forward() unreachable: passing 0
        # never disabled annealing, and step=0 then mined ALL pixels instead
        # of the requested top-k fraction. Store the value unmodified; the
        # `== 0` branch below guards the division.
        self.hard_example_mining_step = hard_example_mining_step
        if self.top_k_percent_pixels is None:
            self.celoss = nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='mean')
        else:
            # Keep per-pixel losses so top-k mining can pick the hard ones.
            self.celoss = nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='none')

    def forward(self, dic_tmp, y, step):
        """Compute one (possibly hard-mined) CE loss per entry.

        Args:
            dic_tmp: list of [B, C, H, W] logit tensors.
            y: list of [B, H, W] ground-truth label tensors.
            step: current training step, used to anneal the mined fraction;
                only read when top_k_percent_pixels is set.

        Returns:
            1-D tensor with one loss value per list entry.
        """
        total_loss = []
        for i in range(len(dic_tmp)):
            pred_logits = dic_tmp[i]
            gts = y[i]
            if self.top_k_percent_pixels is None:
                final_loss = self.celoss(pred_logits, gts)
            else:
                # Only compute the loss for top k percent pixels: first get
                # the per-pixel losses (reduction='none' keeps the shape).
                num_pixels = float(pred_logits.size(2) * pred_logits.size(3))
                pred_logits = pred_logits.view(
                    -1, pred_logits.size(1),
                    pred_logits.size(2) * pred_logits.size(3))
                gts = gts.view(-1, gts.size(1) * gts.size(2))
                pixel_losses = self.celoss(pred_logits, gts)
                if self.hard_example_mining_step == 0:
                    top_k_pixels = int(self.top_k_percent_pixels * num_pixels)
                else:
                    # Linearly anneal from all pixels (ratio=0) down to the
                    # top-k fraction (ratio=1) over hard_example_mining_step.
                    ratio = min(1.0,
                                step / float(self.hard_example_mining_step))
                    top_k_pixels = int(
                        (ratio * self.top_k_percent_pixels +
                         (1.0 - ratio)) * num_pixels)
                # BUGFIX: guard against k == 0 (tiny fractions round down),
                # which would make topk return an empty tensor and the mean
                # NaN.
                top_k_pixels = max(1, top_k_pixels)
                top_k_loss, top_k_indices = torch.topk(pixel_losses,
                                                       k=top_k_pixels,
                                                       dim=1)
                final_loss = torch.mean(top_k_loss)
            final_loss = final_loss.unsqueeze(0)
            total_loss.append(final_loss)
        total_loss = torch.cat(total_loss, dim=0)
        return total_loss