import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', [
    'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
    'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
])
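# Note: the '_ext' kernels loaded above are GPU implementations of the focal
# loss of Lin et al., "Focal Loss for Dense Object Detection":
#     FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t)
# where p_t is the probability the model assigns to the ground-truth class.
# The sigmoid variant treats each class as an independent binary problem,
# while the softmax variant normalises the logits across classes first.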
class SigmoidFocalLossFunction(Function):

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        return g.op(
            'mmcv::MMCVSigmoidFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):

        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        output = input.new_zeros(input.size())
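        # The extension kernel fills `output` with per-position focal loss
        # values; the 'mean'/'sum' reductions below are then applied in Python.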
        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        ctx.save_for_backward(input, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, target, weight = ctx.saved_tensors

        grad_input = input.new_zeros(input.size())

        ext_module.sigmoid_focal_loss_backward(
            input,
            target,
            weight,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input.size(0)
        return grad_input, None, None, None, None, None


sigmoid_focal_loss = SigmoidFocalLossFunction.apply
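# `sigmoid_focal_loss` above is the functional form. Judging from the checks in
# forward(), the expected call is
#     loss = sigmoid_focal_loss(pred, target, gamma, alpha, weight, reduction)
# with `pred` a float tensor of shape (N, num_classes) holding raw logits,
# `target` a LongTensor of shape (N, ) holding class indices, and `weight` an
# optional 1-D per-class weight of length num_classes.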
class SigmoidFocalLoss(nn.Module):

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
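# Minimal usage sketch (illustrative only; assumes a CUDA build of the '_ext'
# extension and made-up tensor shapes):
#
#   criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
#   pred = torch.randn(8, 80, device='cuda')             # (N, num_classes) logits
#   labels = torch.randint(0, 80, (8, ), device='cuda')  # (N, ) class indices
#   loss = criterion(pred, labels)                       # scalar ('mean' reduction)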
class SoftmaxFocalLossFunction(Function):

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        return g.op(
            'mmcv::MMCVSoftmaxFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):

        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]
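        # Compute a numerically stable softmax in Python (shift by the per-row
        # max before exponentiating); the extension kernel below consumes the
        # normalised probabilities rather than the raw logits.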
        channel_stats, _ = torch.max(input, dim=1)
        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
        input_softmax.exp_()

        channel_stats = input_softmax.sum(dim=1)
        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)

        output = input.new_zeros(input.size(0))
        ext_module.softmax_focal_loss_forward(
            input_softmax,
            target,
            weight,
            output,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        ctx.save_for_backward(input_softmax, target, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_softmax, target, weight = ctx.saved_tensors
        buff = input_softmax.new_zeros(input_softmax.size(0))
        grad_input = input_softmax.new_zeros(input_softmax.size())

        ext_module.softmax_focal_loss_backward(
            input_softmax,
            target,
            weight,
            buff,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input_softmax.size(0)
        return grad_input, None, None, None, None, None


softmax_focal_loss = SoftmaxFocalLossFunction.apply
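# `softmax_focal_loss` above is the functional form. Unlike the sigmoid
# variant, forward() converts the raw logits to softmax probabilities itself,
# so callers still pass unnormalised scores of shape (N, num_classes) together
# with (N, ) integer targets.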
class SoftmaxFocalLoss(nn.Module):

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SoftmaxFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return softmax_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
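# Minimal usage sketch (illustrative only; mirrors the sigmoid example above):
#
#   criterion = SoftmaxFocalLoss(gamma=2.0, alpha=0.25)
#   pred = torch.randn(8, 80, device='cuda')             # (N, num_classes) logits
#   labels = torch.randint(0, 80, (8, ), device='cuda')  # (N, ) class indices
#   loss = criterion(pred, labels)                       # scalar ('mean' reduction)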