Spaces:
Runtime error
Runtime error
File size: 2,267 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmdet.registry import MODELS
from .utils import weighted_loss
@weighted_loss
def mse_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Compute the element-wise squared error between ``pred`` and ``target``.

    The undecorated function performs no reduction of its own. It is wrapped
    by ``weighted_loss`` (imported from ``.utils``). Judging from how
    ``MSELoss.forward`` calls it, the wrapped version additionally accepts
    ``weight``, ``reduction`` and ``avg_factor``. See ``weighted_loss`` for
    exactly how those are applied.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target of the prediction.

    Returns:
        Tensor: The unreduced squared error, ``(pred - target) ** 2``
        computed element-wise by ``F.mse_loss``.
    """
    squared_error = F.mse_loss(input=pred, target=target, reduction='none')
    return squared_error
@MODELS.register_module()
class MSELoss(nn.Module):
    """Mean-squared-error loss module, scaled by a constant weight.

    Args:
        reduction (str, optional): How the element-wise loss is reduced to a
            scalar. One of "none", "mean" or "sum". Defaults to "mean".
        loss_weight (float, optional): Multiplier applied to the computed
            loss. Defaults to 1.0.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        # Stored as plain attributes; ``forward`` reads both on every call.
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the weighted MSE loss between ``pred`` and ``target``.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Tensor, optional): Per-prediction loss weight, forwarded
                unchanged to ``mse_loss``. Defaults to None.
            avg_factor (int, optional): Averaging factor, forwarded unchanged
                to ``mse_loss``. Defaults to None.
            reduction_override (str, optional): If given (and non-empty),
                used instead of ``self.reduction`` for this call only. Must
                be None, "none", "mean" or "sum". Defaults to None.

        Returns:
            Tensor: The result of ``mse_loss`` multiplied by
            ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) means "use the configured reduction".
        reduction = reduction_override or self.reduction
        unscaled = mse_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * unscaled
|