File size: 2,480 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) OpenMMLab. All rights reserved.

from mmcv.utils import Registry

from .balanced_mse_loss import BMCLossMD
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss
from .gan_loss import GANLoss
from .iou_loss import BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss
from .mse_loss import KeypointMSELoss, MSELoss
from .prior_loss import (
    CameraPriorLoss,
    JointPriorLoss,
    LimbLengthLoss,
    MaxMixturePrior,
    PoseRegLoss,
    ShapePriorLoss,
    ShapeThresholdPriorLoss,
    SmoothJointLoss,
    SmoothPelvisLoss,
    SmoothTranslationLoss,
)
from .rotaion_distance_loss import RotationDistance
from .smooth_l1_loss import L1Loss, SmoothL1Loss

# Registry that maps loss names to the loss classes imported above;
# ``build_loss`` below looks configs up in it.
LOSSES = Registry('losses')

# Register every loss under an explicitly spelled-out name. The table is
# walked in order, so registration order matches the listing.
for _loss_name, _loss_cls in (
    ('GANLoss', GANLoss),
    ('MSELoss', MSELoss),
    ('KeypointMSELoss', KeypointMSELoss),
    ('ShapePriorLoss', ShapePriorLoss),
    ('PoseRegLoss', PoseRegLoss),
    ('LimbLengthLoss', LimbLengthLoss),
    ('JointPriorLoss', JointPriorLoss),
    ('SmoothJointLoss', SmoothJointLoss),
    ('SmoothPelvisLoss', SmoothPelvisLoss),
    ('SmoothTranslationLoss', SmoothTranslationLoss),
    ('ShapeThresholdPriorLoss', ShapeThresholdPriorLoss),
    ('CameraPriorLoss', CameraPriorLoss),
    ('MaxMixturePrior', MaxMixturePrior),
    ('L1Loss', L1Loss),
    ('SmoothL1Loss', SmoothL1Loss),
    ('CrossEntropyLoss', CrossEntropyLoss),
    ('RotationDistance', RotationDistance),
    ('BMCLossMD', BMCLossMD),
    ('FocalLoss', FocalLoss),
    ('IoULoss', IoULoss),
    ('BoundedIoULoss', BoundedIoULoss),
    ('GIoULoss', GIoULoss),
    ('DIoULoss', DIoULoss),
    ('CIoULoss', CIoULoss),
):
    LOSSES.register_module(name=_loss_name, module=_loss_cls)

# Drop the loop variables so they do not linger as module attributes.
del _loss_name, _loss_cls


def build_loss(cfg):
    """Build a loss object from ``cfg`` using the ``LOSSES`` registry.

    Args:
        cfg: Configuration forwarded unchanged to ``LOSSES.build``, or
            ``None``. Presumably a dict whose ``type`` key names one of the
            registered losses (mmcv ``Registry.build`` convention) -- confirm
            against callers.

    Returns:
        Whatever ``LOSSES.build(cfg)`` returns, or ``None`` when ``cfg`` is
        ``None`` (so an absent loss config simply yields no loss).
    """
    return None if cfg is None else LOSSES.build(cfg)