# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial

import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.utils import digit_version


class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of the conv layer. Default: 1.
        init_cfg (dict, optional): The config dict for initializing the
            module. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg)
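        # A 3x3 grouped conv generates the position encoding directly from
        # the feature map; ``groups=embed_dims`` keeps it lightweight but
        # requires ``in_channels`` to be divisible by ``embed_dims``.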
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)
        self.stride = stride

    def forward(self, x, hw_shape):
        B, N, C = x.shape
        H, W = hw_shape
        feat_token = x
        # convert (B, N, C) to (B, C, H, W)
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W).contiguous()
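        # With stride 1 the input is added back as a residual, so tokens
        # keep their content plus a position-dependent offset.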
        if self.stride == 1:
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x
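

# A minimal usage sketch for ConditionalPositionEncoding (illustrative, not
# part of the original file). It assumes ``in_channels == embed_dims`` so the
# grouped conv is valid, and that the token count matches the spatial shape,
# i.e. ``N == H * W``.
def _cpe_usage_example():
    cpe = ConditionalPositionEncoding(in_channels=64, embed_dims=64)
    tokens = torch.randn(2, 14 * 14, 64)  # (B, N, C) token sequence
    out = cpe(tokens, hw_shape=(14, 14))  # CPE preserves the token shape
    assert out.shape == tokens.shape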


class PositionEncodingFourier(BaseModule):
    """The Position Encoding Fourier (PEF) module.

    The PEF is adopted from `EdgeNeXt <https://arxiv.org/abs/2206.10589>`_.

    Args:
        in_channels (int): Number of input channels. Default: 32.
        embed_dims (int): The feature dimension. Default: 768.
        temperature (int): Temperature for the sinusoidal frequency schedule.
            Default: 10000.
        dtype (torch.dtype): The data type. Default: torch.float32.
        init_cfg (dict, optional): The config dict for initializing the
            module. Default: None.
    """

    def __init__(self,
                 in_channels=32,
                 embed_dims=768,
                 temperature=10000,
                 dtype=torch.float32,
                 init_cfg=None):
        super(PositionEncodingFourier, self).__init__(init_cfg=init_cfg)
        self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1)
        self.scale = 2 * math.pi
        self.in_channels = in_channels
        self.embed_dims = embed_dims
        self.dtype = dtype

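        # ``rounding_mode='floor'`` for torch.div was added in PyTorch 1.8;
        # fall back to torch.floor_divide on older versions.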
        if digit_version(torch.__version__) < digit_version('1.8.0'):
            floor_div = torch.floor_divide
        else:
            floor_div = partial(torch.div, rounding_mode='floor')
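        # Sinusoidal frequency schedule: temperature**(2 * floor(i / 2) / C),
        # as in the standard transformer position encoding.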
        dim_t = torch.arange(in_channels, dtype=self.dtype)
        self.dim_t = temperature**(2 * floor_div(dim_t, 2) / in_channels)

    def forward(self, bhw_shape):
        B, H, W = bhw_shape
        mask = torch.zeros(
            B, H, W, dtype=torch.bool, device=self.proj.weight.device)
        not_mask = ~mask
        eps = 1e-6
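        # Cumulative sums over the all-ones mask give 1-based row/column
        # indices; dividing by the last index scales them into (0, 2*pi].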
        y_embed = not_mask.cumsum(1, dtype=self.dtype)
        x_embed = not_mask.cumsum(2, dtype=self.dtype)
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = self.dim_t.to(mask.device)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
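        # Interleave sin (even channels) and cos (odd channels), then flatten
        # back to (B, H, W, in_channels) for each axis.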
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)

        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.proj(pos)

        return pos
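

# A minimal usage sketch for PositionEncodingFourier (illustrative, not part
# of the original file). PEF only needs the (B, H, W) shape and returns a
# dense positional map of shape (B, embed_dims, H, W) that can be added to a
# feature map.
def _pef_usage_example():
    pef = PositionEncodingFourier(in_channels=32, embed_dims=64)
    pos = pef((2, 14, 14))
    assert pos.shape == (2, 64, 14, 14)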