# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr_detach.models.data2vec.multihead_attention import MultiheadAttention
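
# Fp32LayerNorm: nn.LayerNorm that performs the normalization in fp32 (the input
# and the affine parameters are cast to float) and casts the result back to the
# input dtype, which keeps the statistics stable under fp16/bf16 mixed precision.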
class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
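
# Fp32GroupNorm: same idea as Fp32LayerNorm, but for nn.GroupNorm.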
class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
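
# TransposeLast: swaps the last two dimensions of its input; if deconstruct_idx
# is given, the input is first indexed with it (e.g. when the preceding module
# returns a tuple).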
class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)
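
# SamePad: trims trailing frames from a padded Conv1d output so its length
# matches the input. If causal, the last kernel_size - 1 frames are removed;
# otherwise one frame is removed only when kernel_size is even.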
class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x
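
# pad_to_multiple: right-pads `x` along `dim` so that its size becomes a
# multiple of `multiple`, returning the padded tensor and the number of padded
# positions (0 if no padding was needed).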
def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
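
# gelu_accurate: tanh approximation of GELU (Hendrycks & Gimpel); the constant
# sqrt(2 / pi) is cached on the function object on first use.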
def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )
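
# gelu: exact GELU computed in fp32 and cast back to the input dtype.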
def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)
def get_available_activation_fns():
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated
        "gelu_accurate",
        "tanh",
        "linear",
    ]
def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        # "gelu_fast" is a deprecated name kept for backward compatibility;
        # it maps to the same tanh approximation as "gelu_accurate".
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # Return the functional form so the result is directly callable on a
        # tensor (torch.nn.SiLU is a module class, not an activation function).
        return F.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then the weights of the linear
       layer will be initialized using the normal distribution and the
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then the weights of the embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then the in-projection weights of
       MultiheadAttention will be initialized using the normal distribution
       (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
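

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a [batch, channels, time] tensor and shows what pad_to_multiple,
# SamePad and init_bert_params do; the shapes and values below are examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(2, 8, 50)  # [batch, channels, time]

    # pad_to_multiple right-pads the time axis up to the next multiple of 16
    # and reports how many frames were added (50 -> 64, so 14 frames here).
    padded, remainder = pad_to_multiple(x, multiple=16, dim=-1)
    assert padded.size(-1) == 64 and remainder == 14

    # With an even kernel and padding=kernel_size // 2, Conv1d produces one
    # extra frame; SamePad trims it so the output length matches the input.
    conv = nn.Conv1d(8, 8, kernel_size=4, padding=2)
    assert SamePad(kernel_size=4)(conv(x)).size(-1) == x.size(-1)

    # init_bert_params is meant to be applied over a module tree, e.g. via
    # model.apply(init_bert_params); here it re-initializes a single Linear
    # (weights drawn with std 0.02, bias zeroed).
    lin = nn.Linear(8, 8)
    init_bert_params(lin)
    assert torch.all(lin.bias == 0)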