import torch import torch.nn as nn class LayerScale(nn.Module): def __init__( self, dim: int, init_values: float | torch.Tensor = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma