KyanChen's picture
init
f549064
raw
history blame
No virus
10.7 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmdet.registry import MODELS
from .accuracy import accuracy
from .cross_entropy_loss import cross_entropy
from .utils import weight_reduce_loss
def seesaw_ce_loss(cls_score: Tensor,
                   labels: Tensor,
                   label_weights: Optional[Tensor],
                   cum_samples: Tensor,
                   num_classes: int,
                   p: float,
                   q: float,
                   eps: float,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
    """Calculate the Seesaw CrossEntropy loss.

    The logits of the non-target classes are shifted by the log of a
    per-sample, per-class "seesaw weight" before a standard cross entropy
    is applied. The weight is the product of a mitigation factor (driven by
    the accumulated class counts) and a compensation factor (driven by the
    current predicted probabilities).

    Args:
        cls_score (Tensor): The prediction with shape (N, C),
            C is the number of classes.
        labels (Tensor): The learning label of the prediction. Any integer
            dtype is accepted; it is cast to int64 internally.
        label_weights (Tensor, optional): Sample-wise loss weight. ``None``
            is forwarded unchanged to the reduction helper.
        cum_samples (Tensor): Cumulative samples for each category, with
            length ``num_classes``.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor. The factor is skipped
            when ``p <= 0``.
        q (float): The ``q`` in the compensation factor. The factor is
            skipped when ``q <= 0``.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor.
        reduction (str, optional): The method used to reduce the loss.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        Tensor: The calculated loss.
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    # ``F.one_hot`` and ``F.cross_entropy`` need int64 class indices, so
    # normalise the dtype once instead of at every indexing site.
    labels = labels.long()
    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor
    if p > 0:
        # Entry [i, j] is n_j / n_i; counts are clamped to at least 1 so
        # classes that have not been seen yet do not divide by zero.
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        # (n_j / n_i) ** p where class j is rarer than class i, else 1.
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        # Pick, for every sample, the row that belongs to its label.
        mitigation_factor = sample_weights[labels, :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor
    if q > 0:
        # Detached: the factor only re-weights and receives no gradient.
        scores = F.softmax(cls_score.detach(), dim=1)
        # Build the row indices directly on the target device.
        row_inds = torch.arange(len(scores), device=scores.device)
        self_scores = scores[row_inds, labels]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        # (s_j / s_label) ** q where class j outscores the label, else 1.
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # Adding log(w) to a logit scales its exp() term by w; the target
    # class column is masked out so its own logit is left untouched.
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if label_weights is not None:
        label_weights = label_weights.float()
    loss = weight_reduce_loss(
        loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
    return loss
@MODELS.register_module()
class SeesawLoss(nn.Module):
    """Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021).

    arXiv: https://arxiv.org/abs/2008.10032

    The classifier is expected to output ``num_classes + 2`` channels: the
    first ``num_classes`` are class logits and the last two are objectness
    logits (index 0 for foreground, index 1 for background). The module
    keeps a running count of the labels it has seen in the ``cum_samples``
    buffer, which is updated on every ``forward`` call.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Only False is supported.
        p (float, optional): The ``p`` in the mitigation factor.
            Defaults to 0.8.
        q (float, optional): The ``q`` in the compensation factor.
            Defaults to 2.0.
        num_classes (int, optional): The number of classes.
            Default to 1203 for LVIS v1 dataset.
        eps (float, optional): The minimal value of divisor to smooth
            the computation of compensation factor. Defaults to 1e-2.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of the loss. Defaults to 1.0
        return_dict (bool, optional): Whether return the losses as a dict.
            Default to True.
    """

    def __init__(self,
                 use_sigmoid: bool = False,
                 p: float = 0.8,
                 q: float = 2.0,
                 num_classes: int = 1203,
                 eps: float = 1e-2,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 return_dict: bool = True) -> None:
        super().__init__()
        assert not use_sigmoid
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_dict = return_dict

        # criterion applied to the class logits of positive samples
        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category; the extra last slot counts
        # the background label (``num_classes``)
        self.register_buffer(
            'cum_samples',
            torch.zeros(self.num_classes + 1, dtype=torch.float))

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classifier
        self.custom_accuracy = True

    def _split_cls_score(self, cls_score: Tensor) -> Tuple[Tensor, Tensor]:
        """Split cls_score into class logits and objectness logits.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tuple[Tensor, Tensor]: The score for classes (first C channels)
            and objectness (last 2 channels), respectively.
        """
        # split cls_score to cls_score_classes and cls_score_objectness
        assert cls_score.size(-1) == self.num_classes + 2
        cls_score_classes = cls_score[..., :-2]
        cls_score_objectness = cls_score[..., -2:]
        return cls_score_classes, cls_score_objectness

    def get_cls_channels(self, num_classes: int) -> int:
        """Get custom classification channels.

        Args:
            num_classes (int): The number of classes. Must equal the
                ``num_classes`` this loss was built with.

        Returns:
            int: The custom classification channels, ``num_classes + 2``.
        """
        assert num_classes == self.num_classes
        return num_classes + 2

    def get_activation(self, cls_score: Tensor) -> Tensor:
        """Get custom activation of cls_score.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tensor: The custom activation of cls_score with shape
            (N, C + 1). The first C channels are the class softmax scaled
            by the foreground probability; the last channel is the
            background probability.
        """
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        score_classes = F.softmax(cls_score_classes, dim=-1)
        score_objectness = F.softmax(cls_score_objectness, dim=-1)
        # list indices keep the trailing dimension for broadcasting / cat
        score_pos = score_objectness[..., [0]]
        score_neg = score_objectness[..., [1]]
        score_classes = score_classes * score_pos
        scores = torch.cat([score_classes, score_neg], dim=-1)
        return scores

    def get_accuracy(self, cls_score: Tensor,
                     labels: Tensor) -> Dict[str, Tensor]:
        """Get custom accuracy w.r.t. cls_score and labels.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction. The
                value ``num_classes`` denotes background.

        Returns:
            Dict [str, Tensor]: The accuracy for objectness (key
            ``acc_objectness``, over all samples) and classes (key
            ``acc_classes``, over positive samples only), respectively.
        """
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        acc_objectness = accuracy(cls_score_objectness, obj_labels)
        acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_objectness'] = acc_objectness
        acc['acc_classes'] = acc_classes
        return acc

    def forward(
        self,
        cls_score: Tensor,
        labels: Tensor,
        label_weights: Optional[Tensor] = None,
        avg_factor: Optional[int] = None,
        reduction_override: Optional[str] = None
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Forward function.

        Besides computing the loss, every call adds the label counts of
        ``labels`` to the ``cum_samples`` buffer.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction. The
                value ``num_classes`` denotes background.
            label_weights (Tensor, optional): Sample-wise loss weight.
                All-ones weights are used when omitted.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce
                the loss, replacing ``self.reduction`` for this call.
                Options are "none", "mean" and "sum".

        Returns:
            Tensor | Dict [str, Tensor]:
                if return_dict == False: The calculated loss |
                if return_dict == True: The dict of calculated losses
                for objectness and classes, respectively.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes + 2
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()

        # accumulate the samples for each category in one vectorised step
        # (the indices are unique, so the indexed add cannot collide)
        unique_labels, label_counts = labels.unique(return_counts=True)
        self.cum_samples[unique_labels.long()] += label_counts.to(
            self.cum_samples.dtype)

        if label_weights is not None:
            label_weights = label_weights.float()
        else:
            label_weights = labels.new_ones(labels.size(), dtype=torch.float)

        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        # calculate loss_cls_classes (only need pos samples)
        if pos_inds.sum() > 0:
            loss_cls_classes = self.loss_weight * self.cls_criterion(
                cls_score_classes[pos_inds], labels[pos_inds],
                label_weights[pos_inds], self.cum_samples[:self.num_classes],
                self.num_classes, self.p, self.q, self.eps, reduction,
                avg_factor)
        else:
            # No positives: the sum over an empty selection is zero but is
            # still derived from ``cls_score``.
            loss_cls_classes = cls_score_classes[pos_inds].sum()
        # calculate loss_cls_objectness
        loss_cls_objectness = self.loss_weight * cross_entropy(
            cls_score_objectness, obj_labels, label_weights, reduction,
            avg_factor)

        if self.return_dict:
            loss_cls = dict()
            loss_cls['loss_cls_objectness'] = loss_cls_objectness
            loss_cls['loss_cls_classes'] = loss_cls_classes
        else:
            loss_cls = loss_cls_classes + loss_cls_objectness
        return loss_cls