Spaces:
Runtime error
Runtime error
File size: 26,604 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d
from mmengine.model import caffe2_xavier_init
from mmengine.structures import InstanceData, PixelData
from torch import Tensor
from mmdet.models.layers.pixel_decoder import PixelDecoder
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
OptMultiConfig, reduce_mean)
from ..layers import DetrTransformerDecoder, SinePositionalEncoding
from ..utils import multi_apply, preprocess_panoptic_gt
from .anchor_free_head import AnchorFreeHead
@MODELS.register_module()
class MaskFormerHead(AnchorFreeHead):
    """Implements the MaskFormer head.

    See `Per-Pixel Classification is Not All You Need for Semantic
    Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for feature.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things.
        num_stuff_classes (int): Number of stuff.
        num_queries (int): Number of queries in Transformer.
        pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
            decoder.
        enforce_decoder_input_project (bool): Whether to add a layer
            to change the embed_dim of transformer encoder in pixel decoder
            to the embed_dim of transformer decoder. Defaults to False.
        transformer_decoder (:obj:`ConfigDict` or dict): Config for
            transformer decoder.
        positional_encoding (:obj:`ConfigDict` or dict): Config for
            transformer decoder position encoding.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`. Must contain a
            ``class_weight`` entry, which is also kept on the head.
        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
            Defaults to `FocalLoss`.
        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
            Defaults to `DiceLoss`.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            MaskFormer head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            MaskFormer head.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 feat_channels: int,
                 out_channels: int,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 num_queries: int = 100,
                 pixel_decoder: ConfigType = ...,
                 enforce_decoder_input_project: bool = False,
                 transformer_decoder: ConfigType = ...,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask: ConfigType = dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=20.0),
                 loss_dice: ConfigType = dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     naive_dice=True,
                     loss_weight=1.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        # ``super(AnchorFreeHead, self)`` starts the MRO lookup *after*
        # AnchorFreeHead, so AnchorFreeHead.__init__ itself is skipped.
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        # Work on a shallow copy so the caller's config is not mutated.
        pixel_decoder = pixel_decoder.copy()
        pixel_decoder.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder)
        self.transformer_decoder = DetrTransformerDecoder(
            **transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        # NOTE(review): this is an exact type comparison, not isinstance,
        # so subclasses of PixelDecoder never get the 1x1 projection.
        # Presumably intentional -- confirm before relaxing it.
        needs_projection = (
            self.decoder_embed_dims != in_channels[-1]
            or enforce_decoder_input_project)
        if type(self.pixel_decoder) is PixelDecoder and needs_projection:
            self.decoder_input_proj = Conv2d(
                in_channels[-1], self.decoder_embed_dims, kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()

        self.decoder_pe = SinePositionalEncoding(**positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        # One extra output channel on top of the thing + stuff classes.
        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                train_cfg['sampler'], default_args=dict(context=self))

        # Key access works for both a plain ``dict`` (the signature default)
        # and a ``ConfigDict``; attribute access fails on a plain dict.
        self.class_weight = loss_cls['class_weight']
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def init_weights(self) -> None:
        """Initialize the input projection, pixel decoder and transformer
        decoder weights."""
        # nn.Identity has no parameters; only a real projection conv is
        # initialized here.
        if isinstance(self.decoder_input_proj, Conv2d):
            caffe2_xavier_init(self.decoder_input_proj, bias=0)

        self.pixel_decoder.init_weights()

        # Only parameters with more than one dimension are re-initialized.
        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def preprocess_gt(
            self, batch_gt_instances: InstanceList,
            batch_gt_semantic_segs: List[Optional[PixelData]]
    ) -> InstanceList:
        """Preprocess the ground truth for all images.

        Args:
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``labels``, each is
                ground truth labels of each bbox, with shape (num_gts, )
                and ``masks``, each is ground truth masks of each instances
                of a image, shape (num_gts, h, w).
            batch_gt_semantic_segs (list[Optional[PixelData]]): Ground truth
                of semantic segmentation, each with the shape (1, h, w).
                [0, num_thing_class - 1] means things,
                [num_thing_class, num_class-1] means stuff,
                255 means VOID. It's None when training instance
                segmentation.

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Ground truth class indices
                  for a image, with shape (n, ), n is the sum of
                  number of stuff type and number of instance in a image.
                - masks (Tensor): Ground truth mask for a
                  image, with shape (n, h, w).
        """
        num_imgs = len(batch_gt_instances)
        # multi_apply receives one entry per image for every argument, so
        # the class counts are repeated for each image.
        num_things_list = [self.num_things_classes] * num_imgs
        num_stuff_list = [self.num_stuff_classes] * num_imgs
        gt_labels_list = [
            gt_instances['labels'] for gt_instances in batch_gt_instances
        ]
        gt_masks_list = [
            gt_instances['masks'] for gt_instances in batch_gt_instances
        ]
        gt_semantic_segs = [
            None if gt_semantic_seg is None else gt_semantic_seg.sem_seg
            for gt_semantic_seg in batch_gt_semantic_segs
        ]
        labels, masks = multi_apply(preprocess_panoptic_gt, gt_labels_list,
                                    gt_masks_list, gt_semantic_segs,
                                    num_things_list, num_stuff_list)
        return [
            InstanceData(labels=label, masks=mask)
            for label, mask in zip(labels, masks)
        ]

    def get_targets(
        self,
        cls_scores_list: List[Tensor],
        mask_preds_list: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        return_sampling_results: bool = False
    ) -> Tuple[List[Union[Tensor, int]]]:
        """Compute classification and mask targets for all images for a
        decoder layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.
                  Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights
                  of all images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of
                  all images, as returned by ``_get_targets_single``.
                - mask_weights_list (list[Tensor]): Mask weights of
                  all images. Each with shape (num_queries, ).
                - avg_factor (int): Sum of ``avg_factor`` over the
                  per-image sampling results.
                - sampling_results_list (list): Per-image sampling results.
                  Only present when ``return_sampling_results`` is True.

            additional_returns: Any values that ``_get_targets_single``
            returns beyond its first seven are appended at the end of the
            tuple.
        """
        results = multi_apply(self._get_targets_single, cls_scores_list,
                              mask_preds_list, batch_gt_instances,
                              batch_img_metas)
        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list, pos_inds_list, neg_inds_list,
         sampling_results_list) = results[:7]
        rest_results = list(results[7:])

        avg_factor = sum(sampling_result.avg_factor
                         for sampling_result in sampling_results_list)

        res = (labels_list, label_weights_list, mask_targets_list,
               mask_weights_list, avg_factor)
        if return_sampling_results:
            # One-element tuple: without the trailing comma the parentheses
            # are a no-op and ``tuple + list`` raises TypeError.
            res = res + (sampling_results_list, )

        return res + tuple(rest_results)

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for
                one image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple: a tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                  shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image.
                  shape (num_queries, ).
                - mask_targets (Tensor): Ground truth masks selected by
                  ``pos_assigned_gt_inds``, i.e. one mask per positive
                  query, at the original ground truth resolution.
                - mask_weights (Tensor): Mask weights of each image.
                  shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        gt_masks = gt_instances.masks
        gt_labels = gt_instances.labels

        # The assigner gets ground truth masks resized to the prediction
        # resolution; an empty ground truth is passed through unchanged.
        target_shape = mask_pred.shape[-2:]
        if gt_masks.shape[0] > 0:
            gt_masks_downsampled = F.interpolate(
                gt_masks.unsqueeze(1).float(), target_shape,
                mode='nearest').squeeze(1).long()
        else:
            gt_masks_downsampled = gt_masks

        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        downsampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_masks_downsampled)

        # assign and sample
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=downsampled_gt_instances,
            img_meta=img_meta)
        # The sampler receives the original (not downsampled) ground truth.
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target: every query starts with index ``num_classes`` and
        # positives are overwritten with their assigned ground truth label.
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones(self.num_queries)

        # mask target: only positive queries carry a non-zero mask weight.
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)

    def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor,
                     batch_gt_instances: List[InstanceData],
                     batch_img_metas: List[dict]) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape (num_decoder, batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape (num_decoder, batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. The last
            decoder layer uses the plain keys ``loss_cls``, ``loss_mask``
            and ``loss_dice``; earlier layers use the same keys prefixed
            with ``d{layer_index}.``.
        """
        num_dec_layers = len(all_cls_scores)
        # The same ground truth and metas are reused for every layer.
        batch_gt_instances_list = [
            batch_gt_instances for _ in range(num_dec_layers)
        ]
        img_metas_list = [batch_img_metas for _ in range(num_dec_layers)]
        losses_cls, losses_mask, losses_dice = multi_apply(
            self._loss_by_feat_single, all_cls_scores, all_mask_preds,
            batch_gt_instances_list, img_metas_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_mask'] = losses_mask[-1]
        loss_dict['loss_dice'] = losses_dice[-1]
        # loss from other decoder layers
        aux_losses = zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1])
        for layer_idx, (loss_cls_i, loss_mask_i,
                        loss_dice_i) in enumerate(aux_losses):
            loss_dict[f'd{layer_idx}.loss_cls'] = loss_cls_i
            loss_dict[f'd{layer_idx}.loss_mask'] = loss_mask_i
            loss_dict[f'd{layer_idx}.loss_dice'] = loss_dice_i
        return loss_dict

    def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
                             batch_gt_instances: List[InstanceData],
                             batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder
                layer for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits for a pixel decoder for all
                images. Shape (batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: ``(loss_cls, loss_mask, loss_dice)`` for outputs
            from a single decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]

        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list,
         avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
                                        batch_gt_instances, batch_img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        # The classification loss is normalized by the summed class weight
        # of the target labels rather than by the number of queries.
        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        # Clamp to at least 1 so the mask losses never divide by zero.
        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]
        target_shape = mask_targets.shape[-2:]

        if mask_targets.shape[0] == 0:
            # zero match: the sum over the (empty) positive predictions is
            # returned for both mask losses.
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        # upsample to shape of target
        # shape (num_total_gts, h, w)
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(1),
            target_shape,
            mode='bilinear',
            align_corners=False).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_preds, mask_targets, avg_factor=num_total_masks)

        # mask loss
        # FocalLoss support input of shape (n, num_class)
        h, w = mask_preds.shape[-2:]
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1)
        mask_preds = mask_preds.reshape(-1, 1)
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w)
        mask_targets = mask_targets.reshape(-1)
        # target is (1 - mask_targets) !!!
        loss_mask = self.loss_mask(
            mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)

        return loss_cls, loss_mask, loss_dice

    def forward(self, x: Tuple[Tensor],
                batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Forward function.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[Tensor]: a tuple contains two elements.

                - all_cls_scores (Tensor): Classification scores for each
                  decoder layer, with shape (num_decoder, batch_size,
                  num_queries, cls_out_channels). Note `cls_out_channels`
                  should include background.
                - all_mask_preds (Tensor): Mask scores for each decoder
                  layer, with shape (num_decoder, batch_size,
                  num_queries, h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_size = len(batch_img_metas)

        # Build the padding mask at input resolution: 1 marks padding, 0
        # marks the valid ``img_shape`` region of each image.
        input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
        padding_mask = x[-1].new_ones((batch_size, input_img_h, input_img_w),
                                      dtype=torch.float32)
        for i, img_meta in enumerate(batch_img_metas):
            img_h, img_w = img_meta['img_shape']
            padding_mask[i, :img_h, :img_w] = 0
        # Resize the mask to the resolution of the last feature level.
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1), size=x[-1].shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)

        # when backbone is swin, memory is output of last stage of swin.
        # when backbone is r50, memory is output of transformer encoder.
        mask_features, memory = self.pixel_decoder(x, batch_img_metas)
        pos_embed = self.decoder_pe(padding_mask)
        memory = self.decoder_input_proj(memory)

        # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
        memory = memory.flatten(2).permute(0, 2, 1)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        # shape (batch_size, h * w)
        padding_mask = padding_mask.flatten(1)

        # shape = (num_queries, embed_dims)
        query_embed = self.query_embed.weight
        # shape = (batch_size, num_queries, embed_dims)
        query_embed = query_embed.unsqueeze(0).repeat(batch_size, 1, 1)

        # The decoder input is all zeros; the learned embedding is supplied
        # as the query positional encoding.
        target = torch.zeros_like(query_embed)
        # shape (num_decoder, batch_size, num_queries, embed_dims)
        out_dec = self.transformer_decoder(
            query=target,
            key=memory,
            value=memory,
            query_pos=query_embed,
            key_pos=pos_embed,
            key_padding_mask=padding_mask)

        # cls_scores
        all_cls_scores = self.cls_embed(out_dec)

        # mask_preds: contract the per-query embedding with the per-pixel
        # features over the channel dimension.
        mask_embed = self.mask_embed(out_dec)
        all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
                                      mask_features)

        return all_cls_scores, all_mask_preds

    def loss(
        self,
        x: Tuple[Tensor],
        batch_data_samples: SampleList,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the panoptic
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        batch_img_metas = []
        batch_gt_instances = []
        batch_gt_semantic_segs = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)
            # Samples without a semantic map contribute ``None``.
            if 'gt_sem_seg' in data_sample:
                batch_gt_semantic_segs.append(data_sample.gt_sem_seg)
            else:
                batch_gt_semantic_segs.append(None)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # preprocess ground truth
        batch_gt_instances = self.preprocess_gt(batch_gt_instances,
                                                batch_gt_semantic_segs)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor],
                batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple contains two tensors.

                - mask_cls_results (Tensor): Mask classification logits,
                  shape (batch_size, num_queries, cls_out_channels).
                  Note `cls_out_channels` should include background.
                - mask_pred_results (Tensor): Mask logits, shape
                  (batch_size, num_queries, h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        # Only the output of the last decoder layer is used at test time.
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks to the padded batch input resolution
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        return mask_cls_results, mask_pred_results
|