File size: 7,526 Bytes
18dd6ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import copy
import numpy as np
from typing import Dict
import torch
from scipy.optimize import linear_sum_assignment

from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.structures import Boxes, Instances

from ..config.config import CfgNode as CfgNode_
from .base_tracker import BaseTracker


class BaseHungarianTracker(BaseTracker):
    """
    A base class for all Hungarian trackers.

    Predictions of the current frame are associated with the predictions of
    the previous frame by solving a linear assignment problem over the cost
    matrix returned by ``build_cost_matrix`` (to be provided by subclasses).
    Matched predictions inherit the previous ID, unmatched predictions receive
    a new ID, and unmatched previous predictions may be carried over into the
    current frame (see ``_process_unmatched_prev_idx``).
    """

    @configurable
    def __init__(
        self,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        **kwargs
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of IDs allowed to be tracked
                               (stored only; not referenced by this base class)
            max_lost_frame_count: maximum number of frames an ID can lose
                                  tracking; once an instance's lost frame count
                                  reaches this number it is no longer carried
                                  over and is considered lost forever
            min_box_rel_dim: a ratio; a lost bbox whose width relative to the
                             video width, or height relative to the video
                             height, is smaller than this is removed from
                             tracking
            min_instance_period: a lost instance is only carried over when its
                                 ID_period is strictly greater than this number
            **kwargs: forwarded unchanged to the parent class constructor
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period

    @classmethod
    def from_config(cls, cfg: CfgNode_) -> Dict:
        """
        Translate a config node into constructor keyword arguments.

        Must be overridden by subclasses; the base implementation always raises.

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError("Calling BaseHungarianTracker::from_config")

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Build the assignment cost matrix between current and previous instances.

        Must be overridden by subclasses; the base implementation always raises.
        The result is passed as-is to ``scipy.optimize.linear_sum_assignment``
        and its row / column indices are used as indices into ``instances`` /
        ``prev_instances`` respectively.

        Args:
            instances: D2 Instances, for predictions of the current frame
            prev_instances: D2 Instances, for predictions of the previous frame
        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError("Calling BaseHungarianTracker::build_cost_matrix")

    def update(self, instances: Instances) -> Instances:
        """
        Assign tracking IDs to the predictions of the current frame.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with ID, ID_period and lost_frame_count filled in,
            possibly extended with instances carried over from the previous frame
        Raises:
            NotImplementedError: if ``instances`` has a "pred_keypoints" field
        """
        if instances.has("pred_keypoints"):
            raise NotImplementedError("Need to add support for keypoints")
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            self._untracked_prev_idx = set(range(len(self._prev_instances)))
            cost_matrix = self.build_cost_matrix(instances, self._prev_instances)
            matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix)
            instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx)
            instances = self._process_unmatched_idx(instances, matched_idx)
            instances = self._process_unmatched_prev_idx(instances, matched_prev_idx)
        # Keep an independent snapshot so later changes made by the caller to the
        # returned object do not alter the state used for the next frame.
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        When there is no previous frame, every instance additionally receives a
        fresh ID, an ID_period of 1 and a lost_frame_count of 0.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        num_instances = len(instances)
        for field_name in ("ID", "ID_period", "lost_frame_count"):
            if not instances.has(field_name):
                instances.set(field_name, [None] * num_instances)
        if self._prev_instances is None:
            # Number the first frame from the running counter (as
            # _process_unmatched_idx does) so IDs stay unique even when the
            # counter is not zero at this point.
            first_id = self._id_count
            instances.ID = list(range(first_id, first_id + num_instances))
            self._id_count += num_instances
            instances.ID_period = [1] * num_instances
            instances.lost_frame_count = [0] * num_instances
        return instances

    def _process_matched_idx(
        self, instances: Instances, matched_idx: np.ndarray, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Propagate tracking state from matched previous instances.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_idx: indices into ``instances`` that were matched
            matched_prev_idx: indices into the previous instances; entry i is
                              the match of ``matched_idx[i]``
        Return:
            D2 Instances whose matched entries carry the previous ID, an
            ID_period incremented by one and a lost_frame_count reset to 0
        """
        assert matched_idx.size == matched_prev_idx.size
        prev_instances = self._prev_instances
        for cur_idx, prev_idx in zip(matched_idx, matched_prev_idx):
            instances.ID[cur_idx] = prev_instances.ID[prev_idx]
            instances.ID_period[cur_idx] = prev_instances.ID_period[prev_idx] + 1
            instances.lost_frame_count[cur_idx] = 0
        return instances

    def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances:
        """
        Give a new ID to every current instance that was not matched.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_idx: indices into ``instances`` that were matched
        Return:
            D2 Instances whose unmatched entries have a fresh ID, an ID_period
            of 1 and a lost_frame_count of 0
        """
        untracked_idx = set(range(len(instances))).difference(set(matched_idx))
        for idx in untracked_idx:
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _process_unmatched_prev_idx(
        self, instances: Instances, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Carry over previous instances that found no match in the current frame.

        A previous instance is dropped instead of carried over when its box is
        too small relative to the video, when it has already been lost for
        ``max_lost_frame_count`` frames, or when its ID_period does not exceed
        ``min_instance_period``.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_prev_idx: indices into the previous instances that were matched
        Return:
            D2 Instances, the current instances followed by the carried-over ones
        """
        prev_instances = self._prev_instances
        has_masks = instances.has("pred_masks")

        prev_bboxes = list(prev_instances.pred_boxes)
        prev_classes = list(prev_instances.pred_classes)
        prev_scores = list(prev_instances.scores)
        prev_ID_period = prev_instances.ID_period
        prev_lost_frame_count = prev_instances.lost_frame_count
        prev_masks = list(prev_instances.pred_masks) if has_masks else None

        kept_boxes = []
        kept_masks = []
        kept_classes = []
        kept_scores = []
        kept_ids = []
        kept_periods = []
        kept_lost_counts = []

        untracked_prev_idx = set(range(len(prev_instances))).difference(set(matched_prev_idx))
        for idx in untracked_prev_idx:
            x_left, y_top, x_right, y_bot = prev_bboxes[idx]
            if (
                (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim)
                or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim)
                or prev_lost_frame_count[idx] >= self._max_lost_frame_count
                or prev_ID_period[idx] <= self._min_instance_period
            ):
                continue
            # NOTE(review): ``.numpy()`` only works on CPU tensors -- confirm that
            # callers move predictions to CPU before tracking.
            kept_boxes.append(list(prev_bboxes[idx].numpy()))
            kept_classes.append(int(prev_classes[idx]))
            kept_scores.append(float(prev_scores[idx]))
            kept_ids.append(prev_instances.ID[idx])
            kept_periods.append(prev_ID_period[idx])
            kept_lost_counts.append(prev_lost_frame_count[idx] + 1)
            if has_masks:
                kept_masks.append(prev_masks[idx].numpy().astype(np.uint8))

        fields = {"pred_boxes": Boxes(torch.FloatTensor(kept_boxes))}
        if has_masks:
            if kept_masks:
                # Stack once in numpy; building a tensor directly from a list of
                # ndarrays goes through a slow element-wise conversion.
                fields["pred_masks"] = torch.from_numpy(np.stack(kept_masks)).int()
            else:
                fields["pred_masks"] = torch.IntTensor(kept_masks)
        fields["pred_classes"] = torch.IntTensor(kept_classes)
        fields["scores"] = torch.FloatTensor(kept_scores)
        fields["ID"] = kept_ids
        fields["ID_period"] = kept_periods
        fields["lost_frame_count"] = kept_lost_counts

        untracked_instances = Instances(image_size=instances.image_size, **fields)
        return Instances.cat([instances, untracked_instances])