# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

from mmengine.model import uniform_init
from torch import Tensor, nn

from mmdet.registry import MODELS
from ..layers import SinePositionalEncoding
from ..layers.transformer import (DABDetrTransformerDecoder,
                                  DABDetrTransformerEncoder, inverse_sigmoid)
from .detr import DETR


@MODELS.register_module()
class DABDETR(DETR):
r"""Implementation of `DAB-DETR:
Dynamic Anchor Boxes are Better Queries for DETR.
<https://arxiv.org/abs/2201.12329>`_.
Code is modified from the `official github repo
<https://github.com/IDEA-Research/DAB-DETR>`_.
Args:
with_random_refpoints (bool): Whether to randomly initialize query
embeddings and not update them during training.
Defaults to False.
num_patterns (int): Inspired by Anchor-DETR. Defaults to 0.
"""
def __init__(self,
*args,
with_random_refpoints: bool = False,
num_patterns: int = 0,
**kwargs) -> None:
self.with_random_refpoints = with_random_refpoints
        assert isinstance(num_patterns, int), \
            f'num_patterns should be int but got {num_patterns}.'
self.num_patterns = num_patterns
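        # Set before the parent constructor: `super().__init__` builds the
        # Transformer layers via `_init_layers()`, which reads `num_patterns`.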
super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
self.positional_encoding = SinePositionalEncoding(
**self.positional_encoding)
self.encoder = DABDetrTransformerEncoder(**self.encoder)
self.decoder = DABDetrTransformerDecoder(**self.decoder)
self.embed_dims = self.encoder.embed_dims
self.query_dim = self.decoder.query_dim
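        # Each learnable query embedding stores a `query_dim`-dimensional
        # anchor box (4 values in DAB-DETR: normalized cx, cy, w, h) rather
        # than a high-dimensional content vector.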
self.query_embedding = nn.Embedding(self.num_queries, self.query_dim)
if self.num_patterns > 0:
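            # Pattern embeddings (borrowed from Anchor-DETR): learnable
            # content queries that are shared across all anchor boxes.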
self.patterns = nn.Embedding(self.num_patterns, self.embed_dims)
num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            'embed_dims should be exactly 2 times num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
super(DABDETR, self).init_weights()
if self.with_random_refpoints:
uniform_init(self.query_embedding)
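            # Sample the (x, y) reference coordinates uniformly in [0, 1] and
            # map them through the inverse sigmoid, so that the sigmoid
            # applied inside the decoder recovers them; they are intended to
            # stay fixed during training, as in the official implementation.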
self.query_embedding.weight.data[:, :2] = \
inverse_sigmoid(self.query_embedding.weight.data[:, :2])
self.query_embedding.weight.data[:, :2].requires_grad = False

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of decoder
            and the second dict contains the inputs of the bbox_head function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos'
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two stage' or 'query selection' strategies.
        """
batch_size = memory.size(0)
query_pos = self.query_embedding.weight
query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)
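        # The positional queries are the learned anchor boxes; the content
        # queries start from zeros unless pattern embeddings are configured.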
if self.num_patterns == 0:
query = query_pos.new_zeros(batch_size, self.num_queries,
self.embed_dims)
else:
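            # Tile the patterns from (num_patterns, dim) to
            # (bs, num_patterns * num_queries, dim) and repeat the anchors so
            # that every pattern embedding is paired with every anchor box.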
query = self.patterns.weight[:, None, None, :]\
.repeat(1, self.num_queries, batch_size, 1)\
.view(-1, batch_size, self.embed_dims)\
.permute(1, 0, 2)
query_pos = query_pos.repeat(1, self.num_patterns, 1)
decoder_inputs_dict = dict(
query_pos=query_pos, query=query, memory=memory)
head_inputs_dict = dict()
return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` and `references` of the decoder output.
        """
hidden_states, references = self.decoder(
query=query,
key=memory,
query_pos=query_pos,
key_pos=memory_pos,
key_padding_mask=memory_mask,
            # iterative refinement for anchor boxes
            reg_branches=self.bbox_head.fc_reg)
head_inputs_dict = dict(
hidden_states=hidden_states, references=references)
return head_inputs_dict
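

# Illustrative usage sketch (not part of the module): in a standard
# MMDetection config, the class-level options above are set on the model
# dict, roughly as follows (backbone, neck, encoder, decoder, bbox_head and
# the other required fields are omitted):
#
#   model = dict(
#       type='DABDETR',
#       num_queries=300,
#       with_random_refpoints=False,
#       num_patterns=0,
#       ...)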