KyanChen's picture
init
f549064
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class LayerScale(nn.Module):
    """LayerScale layer.

    Multiplies the input by a learnable per-channel weight of shape
    ``(dim,)``, initialised to ``1e-5``.

    Args:
        dim (int): Dimension of input features.
        inplace (bool): Whether to do the multiplication in-place, in
            which case the input tensor itself is modified and returned.
            Defaults to False.
        data_format (str): The input data format, could be 'channels_last'
            or 'channels_first', representing (B, N, C) and
            (B, C, H, W) format data respectively. Defaults to
            'channels_last'.
    """

    def __init__(self,
                 dim: int,
                 inplace: bool = False,
                 data_format: str = 'channels_last'):
        super().__init__()
        # Kept as an assertion so callers still see AssertionError for an
        # unsupported format.
        assert data_format in ('channels_last', 'channels_first'), \
            "'data_format' could only be channels_last or channels_first."
        self.inplace = inplace
        self.data_format = data_format
        # One learnable scale per channel, starting close to zero.
        self.weight = nn.Parameter(torch.ones(dim) * 1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` channel-wise by the learnable weight.

        Args:
            x (torch.Tensor): Input whose channel axis has size ``dim``.
                For 'channels_last' the channel axis is the last one; for
                'channels_first' it is followed by the spatial axes.

        Returns:
            torch.Tensor: The scaled tensor. When ``inplace`` is True this
            is ``x`` itself, modified in place.
        """
        scale = self.weight
        if self.data_format == 'channels_first':
            # Append singleton spatial axes so the weight broadcasts along
            # the channel axis. At least two are kept, so (C, H, W) and
            # (B, C, H, W) inputs behave exactly as before; higher-rank
            # inputs such as (B, C, D, H, W) get one per spatial axis
            # instead of mis-aligning the channel weight.
            num_spatial = max(x.dim() - 2, 2)
            scale = scale.view(-1, *([1] * num_spatial))
        # For 'channels_last' the 1-D weight already broadcasts over the
        # trailing channel axis.
        return x.mul_(scale) if self.inplace else x * scale