Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import torch | |
import torch.nn as nn | |
class LayerScale(nn.Module):
    """Scale input features by a learnable per-channel weight.

    The layer holds one learnable scalar per channel and multiplies the
    input by it, broadcasting over all non-channel dimensions.

    Args:
        dim (int): Dimension of input features (number of channels).
        inplace (bool): Whether to do the multiplication in-place on the
            input tensor. Defaults to False.
        data_format (str): The input data format, could be 'channels_last'
            or 'channels_first', representing (B, N, C) and
            (B, C, H, W) format data respectively. Defaults to
            'channels_last'.
        layer_scale_init_value (float): Initial value of every element of
            the scale weight. Defaults to 1e-5.
    """

    def __init__(self,
                 dim: int,
                 inplace: bool = False,
                 data_format: str = 'channels_last',
                 layer_scale_init_value: float = 1e-5):
        super().__init__()
        assert data_format in ('channels_last', 'channels_first'), \
            "'data_format' could only be channels_last or channels_first."
        self.inplace = inplace
        self.data_format = data_format
        # One learnable scale per channel, all starting at the same value.
        self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value)

    def forward(self, x):
        """Multiply ``x`` by the per-channel scale.

        Args:
            x (torch.Tensor): Input whose channel axis has size ``dim``;
                the channel axis is the last one for 'channels_last' and
                the third-from-last one for 'channels_first'.

        Returns:
            torch.Tensor: The scaled tensor. When ``inplace`` is True this
            is ``x`` itself, modified in place; otherwise a new tensor.
        """
        if self.data_format == 'channels_first':
            # Reshape to (C, 1, 1) so the weight lines up with the channel
            # axis and broadcasts over the two trailing dimensions.
            scale = self.weight.view(-1, 1, 1)
        else:
            # A 1-D weight already broadcasts along the last axis.
            scale = self.weight
        return x.mul_(scale) if self.inplace else x * scale