File size: 6,298 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import InstanceList, OptInstanceList
from ..losses import carl_loss, isr_p
from ..utils import images_to_levels
from .retina_head import RetinaHead


@MODELS.register_module()
class PISARetinaHead(RetinaHead):
    """RetinaNet head trained with the PISA losses.

    The network structure is inherited unchanged from ``RetinaHead``; only
    the loss computation differs, in two ways:

    1. When ``train_cfg`` has an ``isr`` entry, Importance-based Sample
       Reweighting Positive (ISR-P) replaces the flattened training targets
       (labels, label weights, bbox targets, bbox weights) before the
       losses are computed.
    2. When ``train_cfg`` has a ``carl`` entry, a classification-aware
       regression loss (CARL) is added on top of the classification and
       regression losses.
    """

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the training losses from the head's feature outputs.

        All pyramid levels and all images are flattened into single tensors
        so that each loss is evaluated once over the whole batch rather than
        level by level.

        Args:
            cls_scores (list[Tensor]): Per-level box scores, each of shape
                (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Per-level box energies / deltas, each
                of shape (N, num_anchors * 4, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Ground-truth
                instances of the batch, usually carrying ``bboxes`` and
                ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of every image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Instances to be ignored during training and testing; it
                includes a ``bboxes`` attribute. Defaults to None.

        Returns:
            dict: ``loss_cls`` and ``loss_bbox``, plus the entries returned
            by ``carl_loss`` when ``train_cfg`` defines ``carl``. ``None``
            is returned instead if ``get_targets`` yields ``None``.
        """
        level_shapes = [score_map.shape[-2:] for score_map in cls_scores]
        assert len(level_shapes) == self.prior_generator.num_levels

        anchor_list, valid_flag_list = self.get_anchors(
            level_shapes, batch_img_metas, device=cls_scores[0].device)

        # One channel per class with sigmoid scores, a single one otherwise.
        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            return_sampling_results=True)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor,
         sampling_results_list) = cls_reg_targets

        # Regroup the anchors from per-image lists into per-level tensors,
        # using the anchor count of every level (read from the first image).
        anchors_per_level = [
            level_anchors.size(0) for level_anchors in anchor_list[0]
        ]
        anchors_per_img = [
            torch.cat(img_anchors) for img_anchors in anchor_list
        ]
        all_anchor_list = images_to_levels(anchors_per_img,
                                           anchors_per_level)

        num_imgs = len(batch_img_metas)

        # Predictions: move channels last, split them per anchor, join the
        # levels and finally merge the image and anchor dimensions.
        level_scores = [
            score_map.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  label_channels)
            for score_map in cls_scores
        ]
        joined_scores = torch.cat(level_scores, dim=1)
        flatten_cls_scores = joined_scores.reshape(
            -1, level_scores[0].size(-1))

        level_deltas = [
            delta_map.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for delta_map in bbox_preds
        ]
        joined_deltas = torch.cat(level_deltas, dim=1)
        flatten_bbox_preds = joined_deltas.view(-1,
                                                level_deltas[0].size(-1))

        def _merge_levels(per_level, *trailing_dims):
            """Join per-level tensors along dim 1 and flatten the rest."""
            return torch.cat(per_level, dim=1).reshape(-1, *trailing_dims)

        flatten_labels = _merge_levels(labels_list)
        flatten_label_weights = _merge_levels(label_weights_list)
        flatten_anchors = _merge_levels(all_anchor_list, 4)
        flatten_bbox_targets = _merge_levels(bbox_targets_list, 4)
        flatten_bbox_weights = _merge_levels(bbox_weights_list, 4)

        # ISR-P (optional): the re-weighting runs without gradient tracking
        # and its output replaces the targets used by the losses below.
        if self.train_cfg.get('isr', None) is not None:
            original_targets = (flatten_labels, flatten_label_weights,
                                flatten_bbox_targets, flatten_bbox_weights)
            with torch.no_grad():
                reweighted_targets = isr_p(
                    flatten_cls_scores,
                    flatten_bbox_preds,
                    original_targets,
                    flatten_anchors,
                    sampling_results_list,
                    bbox_coder=self.bbox_coder,
                    loss_cls=self.loss_cls,
                    num_class=self.num_classes,
                    **self.train_cfg['isr'])
            (flatten_labels, flatten_label_weights, flatten_bbox_targets,
             flatten_bbox_weights) = reweighted_targets

        # The losses are evaluated once on the flattened tensors rather than
        # per FPN level, so the weights never have to be split by level
        # again; the result should be the same.
        loss_dict = {
            'loss_cls':
            self.loss_cls(
                flatten_cls_scores,
                flatten_labels,
                flatten_label_weights,
                avg_factor=avg_factor),
            'loss_bbox':
            self.loss_bbox(
                flatten_bbox_preds,
                flatten_bbox_targets,
                flatten_bbox_weights,
                avg_factor=avg_factor),
        }

        # CARL (optional): merge its entries into the returned losses.
        if self.train_cfg.get('carl', None) is not None:
            loss_dict.update(
                carl_loss(
                    flatten_cls_scores,
                    flatten_labels,
                    flatten_bbox_preds,
                    flatten_bbox_targets,
                    self.loss_bbox,
                    **self.train_cfg['carl'],
                    avg_factor=avg_factor,
                    sigmoid=True,
                    num_class=self.num_classes))

        return loss_dict