Spaces:
Runtime error
Runtime error
File size: 52,112 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.models.utils.misc import floordiv
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
from ..layers import mask_matrix_nms
from ..utils import center_of_mass, generate_coordinate, multi_apply
from .base_mask_head import BaseMaskHead
@MODELS.register_module()
class SOLOHead(BaseMaskHead):
    """SOLO mask head used in `SOLO: Segmenting Objects by Locations.

    <https://arxiv.org/abs/1912.04488>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
            Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 4.
        strides (tuple): Downsample factor of each feature map.
            Defaults to (4, 8, 16, 32, 64).
        scale_ranges (tuple[tuple[int, int]]): Range of instance scale
            assigned to each level, in the format
            ((min1, max1), (min2, max2), ...). The scale of a gt is
            ``sqrt(w * h)`` of its bbox and both bounds are inclusive, so a
            range of (16, 64) accepts scales in [16, 64].
        pos_scale (float): Constant scale factor to control the center region.
            Defaults to 0.2.
        num_grids (list[int]): Number of grid cells per side for each level.
            The mask branch of a level outputs ``num_grid ** 2`` channels.
            Defaults to [40, 36, 24, 16, 12].
        cls_down_index (int): Index of the conv in the classification branch
            before which the feature is resized to ``num_grid x num_grid``.
            Defaults to 0.
        loss_mask (dict): Config of mask loss.
        loss_cls (dict): Config of classification loss.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        train_cfg (dict): Training config of head.
        test_cfg (dict): Testing config of head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        strides: tuple = (4, 8, 16, 32, 64),
        scale_ranges: tuple = ((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
        pos_scale: float = 0.2,
        num_grids: list = [40, 36, 24, 16, 12],
        cls_down_index: int = 0,
        loss_mask: ConfigType = dict(
            type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg: ConfigType = dict(
            type='GN', num_groups=32, requires_grad=True),
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: MultiConfig = [
            dict(type='Normal', layer='Conv2d', std=0.01),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_mask_list')),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_cls'))
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        # One output channel per foreground class; the background class id
        # (``num_classes``) only appears as a target label.
        self.cls_out_channels = self.num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.num_grids = num_grids
        # number of FPN feats
        self.num_levels = len(strides)
        assert self.num_levels == len(scale_ranges) == len(num_grids)
        self.scale_ranges = scale_ranges
        self.pos_scale = pos_scale

        self.cls_down_index = cls_down_index
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.norm_cfg = norm_cfg
        self.init_cfg = init_cfg
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # The first mask conv also receives the 2 coordinate channels
            # concatenated in ``forward``.
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        self.conv_mask_list = nn.ModuleList()
        for num_grid in self.num_grids:
            # One output channel per grid cell of the level.
            self.conv_mask_list.append(
                nn.Conv2d(self.feat_channels, num_grid**2, 1))

        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def resize_feats(self, x: Tuple[Tensor]) -> List[Tensor]:
        """Downsample the first feat and upsample last feat in feats.

        The first level is bilinearly downsampled by a factor of 2. The last
        level is bilinearly resized to the spatial size of the second-to-last
        input level. All other levels are returned unchanged.

        Args:
            x  (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            list[Tensor]: Features after resizing, each is a 4D-tensor.
        """
        out = []
        for i in range(len(x)):
            if i == 0:
                out.append(
                    F.interpolate(x[0], scale_factor=0.5, mode='bilinear'))
            elif i == len(x) - 1:
                out.append(
                    F.interpolate(
                        x[i], size=x[i - 1].shape[-2:], mode='bilinear'))
            else:
                out.append(x[i])
        return out

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask prediction.

                - mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                  Each element in the list has shape
                  (batch_size, num_grids**2 ,h ,w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids ,num_grids).

            In eval mode both outputs are sigmoid-activated. Masks are resized
            to twice the size of the first resized level, and scores not
            surviving the local-max check are zeroed.
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mlvl_mask_preds = []
        mlvl_cls_preds = []
        for i in range(self.num_levels):
            # Use a separate name so the ``x`` argument is not shadowed.
            feat = feats[i]
            mask_feat = feat
            cls_feat = feat
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')
            mask_preds = self.conv_mask_list[i](mask_feat)

            # cls branch: resize to the level's grid before the
            # ``cls_down_index``-th conv.
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_preds = F.interpolate(
                    mask_preds.sigmoid(), size=upsampled_size, mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum: a 2x2/stride-1 max-pool padded by 1 and
                # cropped back compares each cell with its top, left and
                # top-left neighbours; non-maximal scores are zeroed.
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mlvl_mask_preds.append(mask_preds)
            mlvl_cls_preds.append(cls_pred)
        return mlvl_mask_preds, mlvl_cls_preds

    def loss_by_feat(self, mlvl_mask_preds: List[Tensor],
                     mlvl_cls_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2 ,h ,w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids ,num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(batch_img_metas)

        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]

        # `BoolTensor` in `pos_masks` represent
        # whether the corresponding point is
        # positive
        pos_mask_targets, labels, pos_masks = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        # change from the outside list meaning multi images
        # to the outside list meaning multi levels
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
        mlvl_pos_masks = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            assert num_levels == len(pos_mask_targets[img_id])
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds[lvl].append(
                    mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
                mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # cat multiple image
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds[lvl] = torch.cat(
                mlvl_pos_mask_preds[lvl], dim=0)
            mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = sum(item.sum() for item in mlvl_pos_masks)
        # dice loss
        loss_mask = []
        for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
            if pred.size()[0] == 0:
                # Zero-valued term that keeps the prediction in the graph.
                loss_mask.append(pred.sum().unsqueeze(0))
                continue
            loss_mask.append(
                self.loss_mask(pred, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None, but a list with
                one entry per level is required in practice.

        Returns:
            Tuple: Usually returns a tuple containing targets for predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element represent
                  the binary mask targets for positive points in this
                  level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is
                  a `BoolTensor` to represent whether the
                  corresponding point in single level
                  is positive, has shape (num_grid **2).
        """
        gt_labels = gt_instances.labels
        device = gt_labels.device

        gt_bboxes = gt_instances.bboxes
        # Instance scale sqrt(w * h), compared against ``scale_ranges``.
        gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                              (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device)

        mlvl_pos_mask_targets = []
        mlvl_labels = []
        mlvl_pos_masks = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides,
                       featmap_sizes, self.num_grids):

            mask_target = torch.zeros(
                [num_grid**2, featmap_size[0], featmap_size[1]],
                dtype=torch.uint8,
                device=device)
            # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
            labels = torch.zeros([num_grid, num_grid],
                                 dtype=torch.int64,
                                 device=device) + self.num_classes
            pos_mask = torch.zeros([num_grid**2],
                                   dtype=torch.bool,
                                   device=device)

            gt_inds = ((gt_areas >= lower_bound) &
                       (gt_areas <= upper_bound)).nonzero().flatten()
            if len(gt_inds) == 0:
                mlvl_pos_mask_targets.append(
                    mask_target.new_zeros(0, featmap_size[0], featmap_size[1]))
                mlvl_labels.append(labels)
                mlvl_pos_masks.append(pos_mask)
                continue
            hit_gt_bboxes = gt_bboxes[gt_inds]
            hit_gt_labels = gt_labels[gt_inds]
            hit_gt_masks = gt_masks[gt_inds, ...]

            # Half-extent of the positive center region around the mask's
            # center of mass.
            pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
                                  hit_gt_bboxes[:, 0]) * self.pos_scale
            pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
                                  hit_gt_bboxes[:, 1]) * self.pos_scale

            # Make sure hit_gt_masks has a value
            valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
            # gt masks are rescaled by 1 / output_stride before being copied
            # into the targets.
            output_stride = stride / 2
            # Grid coordinates are normalised by 4x the first level's feature
            # size. The value does not depend on the instance, so compute it
            # once per level instead of once per gt.
            upsampled_size = (featmap_sizes[0][0] * 4,
                              featmap_sizes[0][1] * 4)

            for gt_mask, gt_label, pos_h_range, pos_w_range, \
                valid_mask_flag in \
                    zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
                        pos_w_ranges, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                center_h, center_w = center_of_mass(gt_mask)

                coord_w = int(
                    floordiv((center_w / upsampled_size[1]), (1. / num_grid),
                             rounding_mode='trunc'))
                coord_h = int(
                    floordiv((center_h / upsampled_size[0]), (1. / num_grid),
                             rounding_mode='trunc'))

                # left, top, right, down
                top_box = max(
                    0,
                    int(
                        floordiv(
                            (center_h - pos_h_range) / upsampled_size[0],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                down_box = min(
                    num_grid - 1,
                    int(
                        floordiv(
                            (center_h + pos_h_range) / upsampled_size[0],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                left_box = max(
                    0,
                    int(
                        floordiv(
                            (center_w - pos_w_range) / upsampled_size[1],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                right_box = min(
                    num_grid - 1,
                    int(
                        floordiv(
                            (center_w + pos_w_range) / upsampled_size[1],
                            (1. / num_grid),
                            rounding_mode='trunc')))

                # Limit the positive region to at most one cell around the
                # center cell.
                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                labels[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                gt_mask = np.uint8(gt_mask.cpu().numpy())
                # Follow the original implementation, F.interpolate is
                # different from cv2 and opencv
                gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
                gt_mask = torch.from_numpy(gt_mask).to(device=device)

                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        # Row-major flattening of grid cell (i, j).
                        index = int(i * num_grid + j)
                        mask_target[index, :gt_mask.shape[0], :gt_mask.
                                    shape[1]] = gt_mask
                        pos_mask[index] = True
            mlvl_pos_mask_targets.append(mask_target[pos_mask])
            mlvl_labels.append(labels)
            mlvl_pos_masks.append(pos_mask)
        return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks

    def predict_by_feat(self, mlvl_mask_preds: List[Tensor],
                        mlvl_cls_scores: List[Tensor],
                        batch_img_metas: List[dict], **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2 ,h ,w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids ,num_grids).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images.Each :obj:`InstanceData` usually contains
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        # Channels-last so that each grid cell becomes one row after view().
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(batch_img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
                for lvl in range(num_levels)
            ]
            mask_pred_list = [
                mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
            ]

            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list = torch.cat(mask_pred_list, dim=0)
            img_meta = batch_img_metas[img_id]

            results = self._predict_by_feat_single(
                cls_pred_list, mask_pred_list, img_meta=img_meta)
            results_list.append(results)

        return results_list

    def _predict_by_feat_single(self,
                                cls_scores: Tensor,
                                mask_preds: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_preds (Tensor): Mask prediction of all points in
                single image, has shape (num_points, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
            it usually contains following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """

        def empty_results(cls_scores, ori_shape):
            """Generate a empty results."""
            results = InstanceData()
            results.scores = cls_scores.new_ones(0)
            # Match the dtypes of the non-empty path: masks there are boolean
            # (``mask_preds > cfg.mask_thr``) and class ids originate from
            # ``nonzero()`` (int64). Float labels would break indexing and
            # concatenation with non-empty results.
            results.masks = cls_scores.new_zeros(
                0, *ori_shape, dtype=torch.bool)
            results.labels = cls_scores.new_zeros(0, dtype=torch.long)
            results.bboxes = cls_scores.new_zeros(0, 4)
            return results

        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(mask_preds)

        featmap_size = mask_preds.size()[-2:]

        h, w = img_meta['img_shape'][:2]
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        if len(cls_scores) == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])

        # Each row is (point index, class id) of a score above threshold.
        inds = score_mask.nonzero()
        cls_labels = inds[:, 1]

        # Filter the mask mask with an area is smaller than
        # stride of corresponding feature level
        lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = cls_scores.new_ones(lvl_interval[-1])
        strides[:lvl_interval[0]] *= self.strides[0]
        for lvl in range(1, self.num_levels):
            strides[lvl_interval[lvl -
                                 1]:lvl_interval[lvl]] *= self.strides[lvl]
        strides = strides[inds[:, 0]]
        mask_preds = mask_preds[inds[:, 0]]

        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        keep = sum_masks > strides
        if keep.sum() == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness: mean soft score inside the binarised mask, used to
        # rescale the classification score.
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        # mask_matrix_nms may return an empty Tensor
        if len(keep_inds) == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        mask_preds = mask_preds[keep_inds]
        # Upsample, crop the padding away (img_shape), then resize to the
        # original image shape.
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0), size=upsampled_size,
            mode='bilinear')[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds, size=img_meta['ori_shape'][:2],
            mode='bilinear').squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results = InstanceData()
        results.masks = masks
        results.labels = labels
        results.scores = scores
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics.
        results.bboxes = results.scores.new_zeros(len(scores), 4)
        return results
@MODELS.register_module()
class DecoupledSOLOHead(SOLOHead):
"""Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations.
<https://arxiv.org/abs/1912.04488>`_
Args:
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
             *args,
             init_cfg: MultiConfig = [
                 dict(type='Normal', layer='Conv2d', std=0.01),
                 dict(type='Normal', std=0.01, bias_prob=0.01,
                      override=dict(name='conv_mask_list_x')),
                 dict(type='Normal', std=0.01, bias_prob=0.01,
                      override=dict(name='conv_mask_list_y')),
                 dict(type='Normal', std=0.01, bias_prob=0.01,
                      override=dict(name='conv_cls'))
             ],
             **kwargs) -> None:
    """Build the decoupled head.

    All positional and keyword arguments are handed to the parent
    ``SOLOHead`` together with ``init_cfg``. The default ``init_cfg``
    targets the two decoupled output branches (``conv_mask_list_x`` and
    ``conv_mask_list_y``) and ``conv_cls``.
    """
    kwargs['init_cfg'] = init_cfg
    super().__init__(*args, **kwargs)
def _init_layers(self) -> None:
    """Initialize layers of the head.

    Builds two parallel mask towers (``mask_convs_x`` / ``mask_convs_y``),
    a classification tower (``cls_convs``), one ``num_grid``-channel
    output conv per level for each mask direction, and ``conv_cls``.
    """

    def conv3x3(in_chn: int) -> ConvModule:
        # 3x3 / stride 1 / padding 1 ConvModule with the head's norm config.
        return ConvModule(
            in_chn,
            self.feat_channels,
            3,
            stride=1,
            padding=1,
            norm_cfg=self.norm_cfg)

    self.mask_convs_x = nn.ModuleList()
    self.mask_convs_y = nn.ModuleList()
    self.cls_convs = nn.ModuleList()
    for layer_idx in range(self.stacked_convs):
        is_first = layer_idx == 0
        # The first mask conv gets one extra input channel: the single
        # coordinate channel concatenated in ``forward``.
        mask_in = self.in_channels + 1 if is_first else self.feat_channels
        cls_in = self.in_channels if is_first else self.feat_channels
        # Keep the per-layer creation order: x tower, y tower, cls tower.
        self.mask_convs_x.append(conv3x3(mask_in))
        self.mask_convs_y.append(conv3x3(mask_in))
        self.cls_convs.append(conv3x3(cls_in))

    self.conv_mask_list_x = nn.ModuleList()
    self.conv_mask_list_y = nn.ModuleList()
    for grid in self.num_grids:
        self.conv_mask_list_x.append(
            nn.Conv2d(self.feat_channels, grid, 3, padding=1))
        self.conv_mask_list_y.append(
            nn.Conv2d(self.feat_channels, grid, 3, padding=1))
    self.conv_cls = nn.Conv2d(
        self.feat_channels, self.cls_out_channels, 3, padding=1)
def forward(self, x: Tuple[Tensor]) -> Tuple:
    """Forward features from the upstream network.

    For each level, the x branch sees channel 0 and the y branch sees
    channel 1 of the coordinate map from ``generate_coordinate``. Each
    branch runs its conv tower, is upsampled by 2 and is projected to
    ``num_grid`` channels. The category feature is resized to
    ``num_grid x num_grid`` right before the ``cls_down_index``-th conv.
    In eval mode, mask outputs are sigmoid-activated and resized to twice
    the size of the first resized level. Category scores are
    sigmoid-activated and zeroed wherever a top, left or top-left
    neighbour is larger.

    Args:
        x (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.

    Returns:
        tuple: A tuple of classification scores and mask prediction.

            - mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
              from x branch. Each element in the list has shape
              (batch_size, num_grids ,h ,w).
            - mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
              from y branch. Each element in the list has shape
              (batch_size, num_grids ,h ,w).
            - mlvl_cls_preds (list[Tensor]): Multi-level scores.
              Each element in the list has shape
              (batch_size, num_classes, num_grids ,num_grids).
    """
    assert len(x) == self.num_levels
    feats = self.resize_feats(x)
    mask_preds_x, mask_preds_y, cls_preds = [], [], []
    for lvl in range(self.num_levels):
        feat = feats[lvl]

        # Decoupled mask branches, each with one coordinate channel.
        coord = generate_coordinate(feat.size(), feat.device)
        branch_x = torch.cat([feat, coord[:, 0:1, ...]], 1)
        branch_y = torch.cat([feat, coord[:, 1:2, ...]], 1)
        for conv_x, conv_y in zip(self.mask_convs_x, self.mask_convs_y):
            branch_x = conv_x(branch_x)
            branch_y = conv_y(branch_y)
        branch_x = F.interpolate(branch_x, scale_factor=2, mode='bilinear')
        branch_y = F.interpolate(branch_y, scale_factor=2, mode='bilinear')
        pred_x = self.conv_mask_list_x[lvl](branch_x)
        pred_y = self.conv_mask_list_y[lvl](branch_y)

        # Category branch.
        cls_feat = feat
        for conv_idx, cls_conv in enumerate(self.cls_convs):
            if conv_idx == self.cls_down_index:
                cls_feat = F.interpolate(
                    cls_feat, size=self.num_grids[lvl], mode='bilinear')
            cls_feat = cls_conv(cls_feat)
        cls_pred = self.conv_cls(cls_feat)

        if not self.training:
            first_h, first_w = feats[0].size()[-2:]
            out_size = (first_h * 2, first_w * 2)
            pred_x = F.interpolate(
                pred_x.sigmoid(), size=out_size, mode='bilinear')
            pred_y = F.interpolate(
                pred_y.sigmoid(), size=out_size, mode='bilinear')
            cls_pred = cls_pred.sigmoid()
            # Point NMS: a padded 2x2 / stride-1 max-pool, cropped back to
            # the grid, holds the max over each cell and its top, left and
            # top-left neighbours.
            neighbour_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
            is_peak = neighbour_max[:, :, :-1, :-1] == cls_pred
            cls_pred = cls_pred * is_peak

        mask_preds_x.append(pred_x)
        mask_preds_y.append(pred_y)
        cls_preds.append(cls_pred)
    return mask_preds_x, mask_preds_y, cls_preds
def loss_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                 mlvl_mask_preds_y: List[Tensor],
                 mlvl_cls_preds: List[Tensor],
                 batch_gt_instances: InstanceList,
                 batch_img_metas: List[dict], **kwargs) -> dict:
    """Calculate the loss based on the features extracted by the mask head.

    The mask of a positive grid cell ``(y, x)`` is predicted as the product
    of the sigmoid of channel ``x`` of the x branch and the sigmoid of
    channel ``y`` of the y branch.

    Args:
        mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
            from x branch. Each element in the list has shape
            (batch_size, num_grids ,h ,w).
        mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
            from y branch. Each element in the list has shape
            (batch_size, num_grids ,h ,w).
        mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
            in the list has shape
            (batch_size, num_classes, num_grids ,num_grids).
        batch_gt_instances (list[:obj:`InstanceData`]): Batch of
            gt_instance. It usually includes ``bboxes``, ``masks``,
            and ``labels`` attributes.
        batch_img_metas (list[dict]): Meta information of multiple images.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    img_ids = range(len(batch_img_metas))
    featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]
    pos_mask_targets, labels, xy_pos_indexes = multi_apply(
        self._get_targets_single,
        batch_gt_instances,
        featmap_sizes=featmap_sizes)

    # Regroup per-image outputs into per-level tensors, concatenating the
    # images of each level in batch order.
    mlvl_pos_mask_targets = []
    mlvl_pos_mask_preds_x = []
    mlvl_pos_mask_preds_y = []
    mlvl_labels = []
    flat_cls_preds = []
    for lvl in range(self.num_levels):
        mlvl_pos_mask_targets.append(
            torch.cat([pos_mask_targets[i][lvl] for i in img_ids], dim=0))
        # Index column 1 (x) selects x-branch channels, column 0 (y)
        # selects y-branch channels.
        mlvl_pos_mask_preds_x.append(
            torch.cat([
                mlvl_mask_preds_x[lvl][i, xy_pos_indexes[i][lvl][:, 1]]
                for i in img_ids
            ], dim=0))
        mlvl_pos_mask_preds_y.append(
            torch.cat([
                mlvl_mask_preds_y[lvl][i, xy_pos_indexes[i][lvl][:, 0]]
                for i in img_ids
            ], dim=0))
        mlvl_labels.append(
            torch.cat([labels[i][lvl].flatten() for i in img_ids], dim=0))
        flat_cls_preds.append(mlvl_cls_preds[lvl].permute(
            0, 2, 3, 1).reshape(-1, self.cls_out_channels))

    # dice loss
    num_pos = 0.
    per_level_losses = []
    for pred_x, pred_y, target in zip(mlvl_pos_mask_preds_x,
                                      mlvl_pos_mask_preds_y,
                                      mlvl_pos_mask_targets):
        num_masks = pred_x.size(0)
        if num_masks > 0:
            num_pos += num_masks
            pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
            per_level_losses.append(
                self.loss_mask(pred_mask, target, reduction_override='none'))
        else:
            # Zero-valued term so both branches still receive a gradient.
            per_level_losses.append(
                (pred_x.sum() + pred_y.sum()).unsqueeze(0))
    all_losses = torch.cat(per_level_losses)
    if num_pos > 0:
        loss_mask = all_losses.sum() / num_pos
    else:
        loss_mask = all_losses.mean()

    # cate
    loss_cls = self.loss_cls(
        torch.cat(flat_cls_preds),
        torch.cat(mlvl_labels),
        avg_factor=num_pos + 1)
    return dict(loss_mask=loss_mask, loss_cls=loss_cls)
def _get_targets_single(self,
                        gt_instances: InstanceData,
                        featmap_sizes: Optional[list] = None) -> tuple:
    """Compute targets for predictions of single image.

    Reuses the parent's target assignment and turns each level's label map
    into the grid coordinates of its positive cells. The decoupled x / y
    branches index those coordinates separately.

    Args:
        gt_instances (:obj:`InstanceData`): Ground truth of instance
            annotations. It should includes ``bboxes``, ``labels``,
            and ``masks`` attributes.
        featmap_sizes (list[:obj:`torch.size`]): Size of each
            feature map from feature pyramid, each element
            means (feat_h, feat_w). Defaults to None.

    Returns:
        Tuple: Usually returns a tuple containing targets for predictions.

            - mlvl_pos_mask_targets (list[Tensor]): Each element represent
              the binary mask targets for positive points in this
              level, has shape (num_pos, out_h, out_w).
            - mlvl_labels (list[Tensor]): Each element is
              classification labels for all
              points in this level, has shape
              (num_grid, num_grid).
            - mlvl_xy_pos_indexes (list[Tensor]): Each element
              in the list contains the index of positive samples in
              corresponding level, has shape (num_pos, 2). The indexes
              come from ``nonzero()`` on the (num_grid, num_grid) label
              map, so the last dimension is ``(row, col)``, i.e.
              ``(index_y, index_x)``.
    """
    mlvl_pos_mask_targets, mlvl_labels, _ = \
        super()._get_targets_single(gt_instances,
                                    featmap_sizes=featmap_sizes)

    # A cell is positive when its label differs from the background id
    # ``self.num_classes``. nonzero() lists such cells in row-major order.
    mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
                           for item in mlvl_labels]
    return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes
def predict_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                    mlvl_mask_preds_y: List[Tensor],
                    mlvl_cls_scores: List[Tensor],
                    batch_img_metas: List[dict], **kwargs) -> InstanceList:
    """Transform a batch of output features extracted from the head into
    mask results.

    For every image, the per-level predictions are flattened and
    concatenated across levels. They are then passed to
    ``self._predict_by_feat_single`` together with that image's meta.
    Any extra ``**kwargs`` are accepted but not used.

    Args:
        mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
            from x branch. Each element in the list has shape
            (batch_size, num_grids, h, w).
        mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
            from y branch. Each element in the list has shape
            (batch_size, num_grids, h, w).
        mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
            in the list has shape
            (batch_size, num_classes, num_grids, num_grids).
        batch_img_metas (list[dict]): Meta information of all images.

    Returns:
        list[:obj:`InstanceData`]: Processed results of multiple
        images. Each :obj:`InstanceData` usually contains the
        following keys.

            - scores (Tensor): Classification scores, has shape
              (num_instance,).
            - labels (Tensor): Has shape (num_instances,).
            - masks (Tensor): Processed mask results, has
              shape (num_instances, h, w).
    """
    # Move the class axis last so each image's level scores can be viewed
    # as (num_points, cls_out_channels).
    channels_last_scores = [
        level_scores.permute(0, 2, 3, 1) for level_scores in mlvl_cls_scores
    ]
    assert len(mlvl_mask_preds_x) == len(channels_last_scores)
    levels = range(len(channels_last_scores))

    results_list = []
    for img_id, img_meta in enumerate(batch_img_metas):
        img_cls_scores = torch.cat([
            channels_last_scores[lvl][img_id].view(
                -1, self.cls_out_channels).detach()
            for lvl in levels
        ], dim=0)
        img_mask_x = torch.cat(
            [mlvl_mask_preds_x[lvl][img_id] for lvl in levels], dim=0)
        img_mask_y = torch.cat(
            [mlvl_mask_preds_y[lvl][img_id] for lvl in levels], dim=0)
        results_list.append(
            self._predict_by_feat_single(
                img_cls_scores,
                img_mask_x,
                img_mask_y,
                img_meta=img_meta))
    return results_list
def _predict_by_feat_single(self,
                            cls_scores: Tensor,
                            mask_preds_x: Tensor,
                            mask_preds_y: Tensor,
                            img_meta: dict,
                            cfg: OptConfigType = None) -> InstanceData:
    """Transform a single image's features extracted from the head into
    mask results.

    Each (point, class) pair above ``cfg.score_thr`` is processed as
    follows:

    1. Map the pair to its pyramid level.
    2. Split its in-level grid index into a row offset (``// num_grid``)
       that picks a channel of ``mask_preds_y`` and a column offset
       (``% num_grid``) that picks a channel of ``mask_preds_x``.
    3. Multiply the two selected maps to form the instance mask.
    4. Drop masks whose area does not exceed the level stride, rescore
       the rest by "maskness", and apply matrix NMS.
    5. Resize the survivors to ``img_meta['ori_shape']``.

    Args:
        cls_scores (Tensor): Classification score of all points
            in single image, has shape (num_points, num_classes).
        mask_preds_x (Tensor): Mask prediction of x branch of
            all points in single image, has shape
            (sum_num_grids, feat_h, feat_w).
        mask_preds_y (Tensor): Mask prediction of y branch of
            all points in single image, has shape
            (sum_num_grids, feat_h, feat_w).
        img_meta (dict): Meta information of corresponding image. Only
            ``img_shape`` and ``ori_shape`` are read.
        cfg (dict, optional): Config used in test phase. Falls back to
            ``self.test_cfg`` when None.

    Returns:
        :obj:`InstanceData`: Processed results of single image.
        It usually contains the following keys.

            - scores (Tensor): Classification scores, has shape
              (num_instance,).
            - labels (Tensor): Has shape (num_instances,).
            - masks (Tensor): Processed binary mask results, has
              shape (num_instances, h, w).
            - bboxes (Tensor): All-zero placeholder of shape
              (num_instances, 4).
    """

    def empty_results(cls_scores, ori_shape):
        """Generate an empty result.

        The dtypes match the non-empty path below. ``labels`` are integer
        class indexes (``long``) and ``masks`` are boolean, so code that
        concatenates results across images does not silently promote them
        to float.
        """
        results = InstanceData()
        results.scores = cls_scores.new_ones(0)
        results.masks = cls_scores.new_zeros(
            0, *ori_shape, dtype=torch.bool)
        results.labels = cls_scores.new_zeros(0, dtype=torch.long)
        results.bboxes = cls_scores.new_zeros(0, 4)
        return results

    cfg = self.test_cfg if cfg is None else cfg
    featmap_size = mask_preds_x.size()[-2:]
    h, w = img_meta['img_shape'][:2]
    # NOTE(review): assumes the mask features are at 1/4 of the padded
    # input resolution; the upsampled map is cropped to ``img_shape`` below.
    upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

    # Keep (point, class) pairs above the threshold. Each row of ``inds``
    # is (flattened point index over all levels, class index).
    score_mask = (cls_scores > cfg.score_thr)
    cls_scores = cls_scores[score_mask]
    inds = score_mask.nonzero()

    # Per-point lookup tables over all levels. Level ``l`` owns
    # num_grids[l] ** 2 classification points and num_grids[l] mask
    # channels in the x / y branch tensors.
    lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
    num_all_points = lvl_interval[-1]
    lvl_start_index = inds.new_ones(num_all_points)
    num_grids = inds.new_ones(num_all_points)
    seg_size = inds.new_tensor(self.num_grids).cumsum(0)
    mask_lvl_start_index = inds.new_ones(num_all_points)
    strides = inds.new_ones(num_all_points)

    # Fill each level's slice by scaling the all-ones tensors.
    lvl_start_index[:lvl_interval[0]] *= 0
    mask_lvl_start_index[:lvl_interval[0]] *= 0
    num_grids[:lvl_interval[0]] *= self.num_grids[0]
    strides[:lvl_interval[0]] *= self.strides[0]
    for lvl in range(1, self.num_levels):
        lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
            lvl_interval[lvl - 1]
        mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
            seg_size[lvl - 1]
        num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
            self.num_grids[lvl]
        strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
            self.strides[lvl]

    # Gather the table entries for the kept points only.
    lvl_start_index = lvl_start_index[inds[:, 0]]
    mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
    num_grids = num_grids[inds[:, 0]]
    strides = strides[inds[:, 0]]

    # Split the in-level point index into a row offset (selects a y-branch
    # channel) and a column offset (selects an x-branch channel).
    y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
    x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
    y_inds = mask_lvl_start_index + y_lvl_offset
    x_inds = mask_lvl_start_index + x_lvl_offset

    cls_labels = inds[:, 1]
    # Decoupled mask: element-wise product of the x and y branch maps.
    mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]

    masks = mask_preds > cfg.mask_thr
    sum_masks = masks.sum((1, 2)).float()
    # Drop masks whose area (in mask-feature pixels) does not exceed the
    # stride of the level they came from.
    keep = sum_masks > strides
    if keep.sum() == 0:
        return empty_results(cls_scores, img_meta['ori_shape'][:2])
    masks = masks[keep]
    mask_preds = mask_preds[keep]
    sum_masks = sum_masks[keep]
    cls_scores = cls_scores[keep]
    cls_labels = cls_labels[keep]

    # Maskness: mean soft mask value inside the binary mask, used to
    # rescore the classification score.
    mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
    cls_scores *= mask_scores

    scores, labels, _, keep_inds = mask_matrix_nms(
        masks,
        cls_labels,
        cls_scores,
        mask_area=sum_masks,
        nms_pre=cfg.nms_pre,
        max_num=cfg.max_per_img,
        kernel=cfg.kernel,
        sigma=cfg.sigma,
        filter_thr=cfg.filter_thr)
    # mask_matrix_nms may return an empty Tensor
    if len(keep_inds) == 0:
        return empty_results(cls_scores, img_meta['ori_shape'][:2])

    # Upsample, crop away padding, then resize to the original image shape.
    mask_preds = mask_preds[keep_inds]
    mask_preds = F.interpolate(
        mask_preds.unsqueeze(0), size=upsampled_size,
        mode='bilinear')[:, :, :h, :w]
    mask_preds = F.interpolate(
        mask_preds, size=img_meta['ori_shape'][:2],
        mode='bilinear').squeeze(0)
    masks = mask_preds > cfg.mask_thr

    results = InstanceData()
    results.masks = masks
    results.labels = labels
    results.scores = scores
    # create an empty bbox in InstanceData to avoid bugs when
    # calculating metrics.
    results.bboxes = results.scores.new_zeros(len(scores), 4)
    return results
@MODELS.register_module()
class DecoupledSOLOLightHead(DecoupledSOLOHead):
    """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
    Locations <https://arxiv.org/abs/1912.04488>`_

    The mask branch and the classification branch are each a stack of
    ``stacked_convs`` 3x3 ConvModules. When ``dcn_cfg`` is given, it is
    used as the ``conv_cfg`` of the last layer of both stacks only.

    Args:
        dcn_cfg (dict, optional): ``conv_cfg`` for the last layer of
            ``mask_convs`` and ``cls_convs``. Presumably this is a
            deformable conv config, as the name suggests. Defaults to None,
            meaning every layer uses the default conv.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 dcn_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        # NOTE: the ``init_cfg`` default list is shared across calls. It is
        # only forwarded to the parent here and never mutated in this class.
        # It is kept as-is so that an explicit ``init_cfg=None`` still
        # reaches the parent unchanged.
        assert dcn_cfg is None or isinstance(dcn_cfg, dict)
        # Set before ``super().__init__``. Presumably the parent constructor
        # calls ``_init_layers``, which reads ``self.dcn_cfg`` (TODO confirm).
        self.dcn_cfg = dcn_cfg
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self) -> None:
        """Build the mask/cls conv stacks and the output convs.

        There is one x-branch and one y-branch mask conv per level, each
        with ``num_grid`` output channels, plus a shared cls conv.
        """
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # Only the last layer of each stack uses ``dcn_cfg``.
            if self.dcn_cfg is not None \
                    and i == self.stacked_convs - 1:
                conv_cfg = self.dcn_cfg
            else:
                conv_cfg = None
            # The first mask layer also receives the 2 coordinate channels
            # concatenated in ``forward``.
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor. Must contain exactly ``self.num_levels``
                elements.

        Returns:
            tuple: ``(mlvl_mask_preds_x, mlvl_mask_preds_y, mlvl_cls_preds)``

                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).

            In eval mode, the mask predictions are sigmoid probabilities
            resized to twice the size of the first resized feature map.
            The cls predictions are sigmoid scores with non-local-maxima
            zeroed out.
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)

        # Test-time target size is the same for every level, so compute it
        # once instead of once per level.
        if not self.training and feats:
            feat_h, feat_w = feats[0].size()[-2:]
            upsampled_size = (feat_h * 2, feat_w * 2)

        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            level_feat = feats[i]
            mask_feat = level_feat
            cls_feat = level_feat

            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)
            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)
            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')
            mask_pred_x = self.conv_mask_list_x[i](mask_feat)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat)

            # cls branch: resize to the level's S x S grid before the
            # ``cls_down_index``-th conv.
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)
            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # Point NMS: keep a score only if it equals the max of its
                # 2x2 neighbourhood.
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds
|