mrfakename's picture
Super-squash branch 'main' using huggingface_hub
0102e16 verified
raw
history blame
14.6 kB
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
from funasr_detach.models.transformer.layer_norm import LayerNorm
from funasr_detach.models.encoder.abs_encoder import AbsEncoder
import math
from funasr_detach.models.transformer.utils.repeat import repeat
from funasr_detach.models.transformer.utils.multi_layer_conv import FsmnFeedForward
class FsmnBlock(torch.nn.Module):
    """Depthwise 1-D convolution over time with a residual connection.

    Each feature channel is filtered independently (``groups=n_feat``) by a
    bias-free convolution of width ``kernel_size``.  The input is zero-padded
    by ``kernel_size - 1`` frames in total, so the output has the same number
    of frames as the input.

    Args:
        n_feat: Number of feature channels (last dimension of the input).
        dropout_rate: Dropout probability applied after the residual sum.
        kernel_size: Width of the depthwise convolution kernel.
        fsmn_shift: When positive, this many frames of padding are moved from
            the right side to the left side, shifting the filter's context
            window toward the past.
    """

    def __init__(
        self,
        n_feat,
        dropout_rate,
        kernel_size,
        fsmn_shift=0,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Padding is applied explicitly (the convolution itself uses
        # padding=0) so that it can be asymmetric.  left + right always equals
        # kernel_size - 1, which keeps the time length unchanged.
        left_padding = (kernel_size - 1) // 2
        if fsmn_shift > 0:
            left_padding = left_padding + fsmn_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the memory block.

        Args:
            inputs: 3-D tensor ``(batch, time, n_feat)``.
            mask: Optional tensor reshaped to ``(batch, -1, 1)`` and multiplied
                into both the input and the output.  When ``None`` no masking
                is applied.
            mask_shfit_chunk: Optional extra mask multiplied into ``mask``.
                It is only used when ``mask`` is given.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Unpacking into three names also rejects inputs that are not 3-D.
        b, _, _ = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        # Conv1d expects (batch, channels, time).
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x = x + inputs
        x = self.dropout(x)
        # Previously this multiplied by ``mask`` unconditionally and raised a
        # TypeError when ``mask`` was None.
        if mask is not None:
            x = x * mask
        return x
class EncoderLayer(torch.nn.Module):
    """One FSMN layer: feed-forward projection, memory block, optional skip.

    Args:
        in_size: Feature dimension of the layer input.
        size: Feature dimension of the layer output.
        feed_forward: Callable applied to the input; the first element of its
            return value is used.
        fsmn_block: Callable invoked as ``fsmn_block(context, mask)``.
        dropout_rate: Dropout probability applied to the memory block output.
    """

    def __init__(self, in_size, size, feed_forward, fsmn_block, dropout_rate=0.0):
        super().__init__()
        self.in_size, self.size = in_size, size
        self.ffn = feed_forward
        self.memory = fsmn_block
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self, xs_pad: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(output, mask)``; ``mask`` is passed through.

        The original comment on this method states that ``xs_pad`` is laid out
        as (Batch, Time, Dim).
        """
        projected = self.ffn(xs_pad)[0]
        out = self.dropout(self.memory(projected, mask))
        # A residual connection is only possible when the input and output
        # feature dimensions agree.
        if self.in_size != self.size:
            return out, mask
        return out + xs_pad, mask
class FsmnEncoder(AbsEncoder):
    """Encoder made of stacked FSMN layers followed by feed-forward layers."""

    def __init__(
        self,
        in_units,
        filter_size,
        fsmn_num_layers,
        dnn_num_layers,
        num_memory_units=512,
        ffn_inner_dim=2048,
        dropout_rate=0.0,
        shift=0,
        position_encoder=None,
        sample_rate=1,
        out_units=None,
        tf2torch_tensor_name_prefix_torch="post_net",
        tf2torch_tensor_name_prefix_tf="EAND/post_net",
    ):
        """Initializes the parameters of the encoder.

        Args:
            in_units: Input feature dimension of the first FSMN layer.
            filter_size: The total order of the memory block (kernel size of
                each layer's ``FsmnBlock``).
            fsmn_num_layers: The number of fsmn layers.
            dnn_num_layers: The number of dnn layers.
            num_memory_units: The number of memory units; this is the output
                dimension of every fsmn and dnn layer.
            ffn_inner_dim: The number of units of the inner linear
                transformation in the feed forward layer.
            dropout_rate: The probability to drop units from the outputs.
            shift: Left padding, to control delay.  Either a list with one
                entry per fsmn layer or a single value used for every layer.
            position_encoder: Optional callable applied to the inputs before
                dropout, or ``None``.
            sample_rate: A list with one entry per fsmn layer or a single
                value repeated per layer.  It is only stored on the instance;
                nothing else in this class reads it.
            out_units: When not ``None``, a kernel-size-1 ``Conv1d`` projects
                the final output to this many channels.
            tf2torch_tensor_name_prefix_torch: Parameter-name prefix on the
                torch side, used when converting TensorFlow checkpoints.
            tf2torch_tensor_name_prefix_tf: Variable-name prefix on the
                TensorFlow side, used when converting checkpoints.
        """
        super().__init__()
        self.in_units = in_units
        self.filter_size = filter_size
        self.fsmn_num_layers = fsmn_num_layers
        self.dnn_num_layers = dnn_num_layers
        self.num_memory_units = num_memory_units
        self.ffn_inner_dim = ffn_inner_dim
        self.dropout_rate = dropout_rate
        # Scalars are broadcast to one value per fsmn layer.
        self.shift = shift if isinstance(shift, list) else [shift] * fsmn_num_layers
        self.sample_rate = (
            sample_rate
            if isinstance(sample_rate, list)
            else [sample_rate] * fsmn_num_layers
        )
        self.position_encoder = position_encoder
        self.dropout = nn.Dropout(dropout_rate)
        self.out_units = out_units
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        # Only the first layer sees ``in_units`` features; every later layer
        # consumes the ``num_memory_units`` output of its predecessor.
        self.fsmn_layers = repeat(
            self.fsmn_num_layers,
            lambda lnum: EncoderLayer(
                in_units if lnum == 0 else num_memory_units,
                num_memory_units,
                FsmnFeedForward(
                    in_units if lnum == 0 else num_memory_units,
                    ffn_inner_dim,
                    num_memory_units,
                    1,
                    dropout_rate,
                ),
                FsmnBlock(
                    num_memory_units, dropout_rate, filter_size, self.shift[lnum]
                ),
            ),
        )
        self.dnn_layers = repeat(
            dnn_num_layers,
            lambda lnum: FsmnFeedForward(
                num_memory_units,
                ffn_inner_dim,
                num_memory_units,
                1,
                dropout_rate,
            ),
        )
        if out_units is not None:
            self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)

    def output_size(self) -> int:
        """Return ``num_memory_units``.

        NOTE(review): this value is returned even when ``out_units`` is set and
        ``forward`` projects to ``out_units`` channels -- confirm with callers.
        """
        return self.num_memory_units

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch.

        Args:
            xs_pad: Padded input features (presumably ``(batch, time,
                in_units)`` -- TODO confirm against callers).
            ilens: Sequence lengths, passed to ``make_pad_mask``.
            prev_states: Accepted but not used by this method.

        Returns:
            ``(encoded, ilens, None)``.
        """
        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        inputs = self.dropout(inputs)
        # The pad mask is inverted and given a singleton middle axis.
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        inputs = self.fsmn_layers(inputs, masks)[0]
        inputs = self.dnn_layers(inputs)[0]

        if self.out_units is not None:
            # Conv1d works on (batch, channels, time).
            inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)

        return inputs, ilens, None

    def gen_tf2torch_map_dict(self):
        """Build the torch-name -> TensorFlow-variable conversion table.

        Keys are torch parameter names in which the layer number is replaced by
        the literal placeholder ``layeridx``.  Each value is a dict with:

        * ``"name"``: the TensorFlow variable name (same placeholder),
        * ``"squeeze"``: axis to squeeze from the TF array, or ``None``,
        * ``"transpose"``: axis permutation to apply, or ``None``.
        """
        prefix_torch = self.tf2torch_tensor_name_prefix_torch
        prefix_tf = self.tf2torch_tensor_name_prefix_tf

        # torch: conv1d.weight in "out_channel in_channel kernel_size"
        # tf   : conv1d.weight in "kernel_size in_channel out_channel"
        # torch: linear.weight in "out_channel in_channel"
        # tf   : dense.weight in "in_channel out_channel"
        conv_kernel = (2, 1, 0)

        # (torch suffix, tf suffix, squeeze axis, transpose axes)
        table = [
            # for fsmn_layers
            (
                "fsmn_layers.layeridx.ffn.norm.bias",
                "fsmn_layer_layeridx/ffn/LayerNorm/beta",
                None,
                None,
            ),
            (
                "fsmn_layers.layeridx.ffn.norm.weight",
                "fsmn_layer_layeridx/ffn/LayerNorm/gamma",
                None,
                None,
            ),
            (
                "fsmn_layers.layeridx.ffn.w_1.bias",
                "fsmn_layer_layeridx/ffn/conv1d/bias",
                None,
                None,
            ),
            (
                "fsmn_layers.layeridx.ffn.w_1.weight",
                "fsmn_layer_layeridx/ffn/conv1d/kernel",
                None,
                conv_kernel,
            ),
            (
                "fsmn_layers.layeridx.ffn.w_2.weight",
                "fsmn_layer_layeridx/ffn/conv1d_1/kernel",
                None,
                conv_kernel,
            ),
            # e.g. (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)
            (
                "fsmn_layers.layeridx.memory.fsmn_block.weight",
                "fsmn_layer_layeridx/memory/depth_conv_w",
                0,
                (1, 2, 0),
            ),
            # for dnn_layers
            (
                "dnn_layers.layeridx.norm.bias",
                "dnn_layer_layeridx/LayerNorm/beta",
                None,
                None,
            ),
            (
                "dnn_layers.layeridx.norm.weight",
                "dnn_layer_layeridx/LayerNorm/gamma",
                None,
                None,
            ),
            (
                "dnn_layers.layeridx.w_1.bias",
                "dnn_layer_layeridx/conv1d/bias",
                None,
                None,
            ),
            (
                "dnn_layers.layeridx.w_1.weight",
                "dnn_layer_layeridx/conv1d/kernel",
                None,
                conv_kernel,
            ),
            (
                "dnn_layers.layeridx.w_2.weight",
                "dnn_layer_layeridx/conv1d_1/kernel",
                None,
                conv_kernel,
            ),
        ]
        if self.out_units is not None:
            # add output layer
            table.extend(
                [
                    ("conv1d.weight", "conv1d/kernel", None, conv_kernel),
                    ("conv1d.bias", "conv1d/bias", None, None),
                ]
            )

        return {
            "{}.{}".format(prefix_torch, torch_suffix): {
                "name": "{}/{}".format(prefix_tf, tf_suffix),
                "squeeze": squeeze,
                "transpose": transpose,
            }
            for torch_suffix, tf_suffix, squeeze, transpose in table
        }

    @staticmethod
    def _load_tf_tensor(name, name_tf, spec, var_dict_tf, var_dict_torch):
        """Convert one TF array to a float32 CPU tensor shaped like ``name``.

        Applies the squeeze/transpose described by ``spec`` and asserts that
        the result matches the size of ``var_dict_torch[name]``.
        """
        data_tf = var_dict_tf[name_tf]
        if spec["squeeze"] is not None:
            data_tf = np.squeeze(data_tf, axis=spec["squeeze"])
        if spec["transpose"] is not None:
            data_tf = np.transpose(data_tf, spec["transpose"])
        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
        assert (
            var_dict_torch[name].size() == data_tf.size()
        ), "{}, {}, {} != {}".format(
            name, name_tf, var_dict_torch[name].size(), data_tf.size()
        )
        logging.info(
            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
            )
        )
        return data_tf

    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):
        """Translate TensorFlow variables into torch tensors for this encoder.

        Args:
            var_dict_tf: Mapping from TensorFlow variable name to numpy array.
            var_dict_torch: Mapping from torch parameter name to tensor; only
                names starting with ``tf2torch_tensor_name_prefix_torch`` are
                processed.

        Returns:
            Dict mapping each converted torch parameter name to its new
            float32 CPU tensor.  Names without a mapping are skipped with a
            warning.

        Raises:
            KeyError: If a mapped TensorFlow variable is absent from
                ``var_dict_tf``.
            AssertionError: If a converted tensor's size differs from the
                torch parameter's size.
        """
        map_dict = self.gen_tf2torch_map_dict()
        prefix = self.tf2torch_tensor_name_prefix_torch

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if not name.startswith(prefix):
                continue

            if name in map_dict:
                # Names without a layer index (e.g. the output conv1d).
                spec = map_dict[name]
                name_tf = spec["name"]
            else:
                # The prefix may itself contain ".", so collapse it to a single
                # token before splitting; the layer index is then field 2.
                names = name.replace(prefix, "todo").split(".")
                try:
                    layeridx = int(names[2])
                except (IndexError, ValueError):
                    # Not a layer-indexed parameter and not mapped: report it
                    # like any other unmapped name instead of crashing.
                    logging.warning("{} is missed from tf checkpoint".format(name))
                    continue
                name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                if name_q not in map_dict:
                    logging.warning("{} is missed from tf checkpoint".format(name))
                    continue
                spec = map_dict[name_q]
                name_tf = spec["name"].replace("layeridx", "{}".format(layeridx))

            var_dict_torch_update[name] = self._load_tf_tensor(
                name, name_tf, spec, var_dict_tf, var_dict_torch
            )

        return var_dict_torch_update