mrfakename's picture
Super-squash branch 'main' using huggingface_hub
0102e16 verified
raw
history blame
43.3 kB
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import logging
import numpy as np
from typing import Tuple
from funasr_detach.register import tables
from funasr_detach.models.scama import utils as myutils
from funasr_detach.models.transformer.utils.repeat import repeat
from funasr_detach.models.transformer.layer_norm import LayerNorm
from funasr_detach.models.transformer.embedding import PositionalEncoding
from funasr_detach.models.paraformer.decoder import (
DecoderLayerSANM,
ParaformerSANMDecoder,
)
from funasr_detach.models.sanm.positionwise_feed_forward import (
PositionwiseFeedForwardDecoderSANM,
)
from funasr_detach.models.sanm.attention import (
MultiHeadedAttentionSANMDecoder,
MultiHeadedAttentionCrossAtt,
)
class ContextualDecoderLayer(torch.nn.Module):
    """SANM decoder layer that also returns its intermediate outputs.

    ``forward`` runs feed-forward -> self-attention (added to the layer input
    as residual) -> cross-attention over ``memory`` (added to the
    post-self-attention state as residual). Besides the final output it
    returns the post-self-attention state and the raw cross-attention output,
    so a caller can fuse extra information between them.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a ContextualDecoderLayer.

        Args:
            size: feature dimension used for the LayerNorms.
            self_attn: called as ``self_attn(x, tgt_mask, cache=cache)`` and
                expected to return ``(output, cache)``.
            src_attn: called as ``src_attn(x, memory, memory_mask)``.
            feed_forward: module applied first to the (normalized) input.
            dropout_rate: dropout applied to the attention outputs before
                they are added to their residuals.
            normalize_before: if True, apply ``norm1``/``norm2``/``norm3``
                before the corresponding sub-block.
            concat_after: only controls whether ``concat_linear1/2`` are
                created; ``forward`` never uses them.
        """
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # NOTE(review): unused by forward(); presumably kept so that
            # checkpoints containing these weights still load — confirm.
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)

    def forward(
        self,
        tgt,
        tgt_mask,
        memory,
        memory_mask,
        cache=None,
    ):
        """Run the layer.

        Args:
            tgt: decoder input tensor, or a tuple whose first element is it.
            tgt_mask: mask forwarded to ``self_attn`` and returned unchanged.
            memory: keys/values for the cross-attention.
            memory_mask: mask for ``memory``, forwarded to ``src_attn``.
            cache: self-attention cache; ignored (reset to None) in training.

        Returns:
            (x, tgt_mask, x_self_attn, x_src_attn): final output, the
            unchanged mask, the post-self-attention residual state, and the
            raw cross-attention output (before dropout/residual).
        """
        if isinstance(tgt, tuple):
            tgt, _ = tgt

        # The residual for the self-attention step is the layer input itself;
        # the feed-forward output only feeds the self-attention.
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        if self.normalize_before:
            tgt = self.norm2(tgt)
        if self.training:
            cache = None
        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
        x = residual + self.dropout(x)
        x_self_attn = x

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = self.src_attn(x, memory, memory_mask)
        x_src_attn = x

        x = residual + self.dropout(x)
        return x, tgt_mask, x_self_attn, x_src_attn
class ContextualBiasDecoder(torch.nn.Module):
    """Single cross-attention block over contextual (bias) embeddings.

    Applies an optional pre-LayerNorm, cross-attention against ``memory`` and
    dropout. There is no residual connection and no feed-forward network.
    """

    def __init__(
        self,
        size,
        src_attn,
        dropout_rate,
        normalize_before=True,
    ):
        """Construct a ContextualBiasDecoder.

        Args:
            size: feature dimension of the LayerNorm.
            src_attn: called as ``src_attn(x, memory, memory_mask)``; when
                None the block is a pass-through.
            dropout_rate: dropout applied to the attention output.
            normalize_before: if True, normalize the query with ``norm3``.
        """
        super().__init__()
        self.size = size
        self.src_attn = src_attn
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Attend ``tgt`` to ``memory``.

        Returns:
            (output, tgt_mask, memory, memory_mask, cache); every element
            except the first is returned unchanged.
        """
        if self.src_attn is None:
            # No attention module configured: hand the input back untouched.
            return tgt, tgt_mask, memory, memory_mask, cache
        query = self.norm3(tgt) if self.normalize_before else tgt
        attended = self.src_attn(query, memory, memory_mask)
        return self.dropout(attended), tgt_mask, memory, memory_mask, cache
@tables.register("decoder_classes", "ContextualParaformerDecoder")
class ContextualParaformerDecoder(ParaformerSANMDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    Contextual (bias) variant of the SANM Paraformer decoder. The stack is:
    ``decoders`` (``att_layer_num - 1`` attention layers), ``last_decoder``
    (whose self-attention output is also attended against the contextual
    embeddings by ``bias_decoder`` and fused back in through
    ``bias_output``), optional attention-free ``decoders2`` layers, and a
    single feed-forward-only ``decoders3`` layer.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        # One if/elif chain: previously "none" set embed to None and then fell
        # through to a second chain whose ``else`` raised ValueError.
        # NOTE(review): the parent constructor may reject "none" on its own —
        # not visible here.
        if input_layer == "none":
            self.embed = None
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        # All attention layers except the last one.
        self.decoders = repeat(
            att_layer_num - 1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim,
                    self_attention_dropout_rate,
                    kernel_size,
                    sanm_shfit=sanm_shfit,
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForwardDecoderSANM(
                    attention_dim, linear_units, dropout_rate
                ),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.dropout = torch.nn.Dropout(dropout_rate)
        # Cross-attention from the last layer's self-attention output to the
        # contextual embeddings.
        self.bias_decoder = ContextualBiasDecoder(
            size=attention_dim,
            src_attn=MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            dropout_rate=dropout_rate,
            normalize_before=True,
        )
        # 1x1 conv fusing [acoustic cross-attention, contextual attention]
        # (2 * attention_dim channels) back to attention_dim.
        self.bias_output = torch.nn.Conv1d(
            attention_dim * 2, attention_dim, 1, bias=False
        )
        self.last_decoder = ContextualDecoderLayer(
            attention_dim,
            MultiHeadedAttentionSANMDecoder(
                attention_dim,
                self_attention_dropout_rate,
                kernel_size,
                sanm_shfit=sanm_shfit,
            ),
            MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForwardDecoderSANM(
                attention_dim, linear_units, dropout_rate
            ),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            # Extra layers without cross-attention.
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim,
                        self_attention_dropout_rate,
                        kernel_size,
                        sanm_shfit=0,
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(
                        attention_dim, linear_units, dropout_rate
                    ),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        # Final feed-forward-only layer.
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(
                    attention_dim, linear_units, dropout_rate
                ),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        clas_scale: float = 1.0,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder with contextual biasing.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: decoder input, fed directly to the first decoder layer
                (``self.embed`` is NOT applied here); presumably already
                embedded, (batch, maxlen_out, attention_dim) — confirm.
            ys_in_lens: (batch)
            contextual_info: contextual (bias) embeddings attended by
                ``bias_decoder``; every entry along dim 1 is treated as valid.
            clas_scale: multiplier on the contextual attention output before
                it is fused through ``bias_output``.
            return_hidden: if True, skip ``output_layer``.
        Returns:
            (tuple): tuple containing:
            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True and return_hidden is False,
                otherwise the hidden states.
            olens: ``tgt_mask.sum(1)``; presumably (batch, 1) because of the
                trailing singleton axis added to ``tgt_mask`` — confirm.
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related: all contextual entries are unmasked.
        contextual_length = (
            torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        )
        contextual_mask = myutils.sequence_mask(
            contextual_length, device=memory.device
        )[:, None, :]
        cx, tgt_mask, _, _, _ = self.bias_decoder(
            x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask
        )

        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx * clas_scale], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(x)
        return x, olens

    def gen_tf2torch_map_dict(self):
        """Build the torch-name -> TF-variable mapping used by ``convert_tf2torch``.

        Keys are torch parameter names in which per-layer indices are
        replaced by the literal placeholder ``layeridx``. Values are dicts
        with ``name`` (TF variable name, possibly containing ``layeridx``),
        ``squeeze`` (axis for ``np.squeeze`` or None) and ``transpose`` (axes
        for ``np.transpose`` or None). For ``output_layer`` entries the three
        fields are parallel lists of candidates; the first candidate present
        in the TF checkpoint is used.

        Reads ``self.tf2torch_tensor_name_prefix_torch`` and
        ``self.tf2torch_tensor_name_prefix_tf``, which are not set in this
        class (presumably by the parent or the caller — verify).
        """
        prefix_torch = self.tf2torch_tensor_name_prefix_torch
        prefix_tf = self.tf2torch_tensor_name_prefix_tf
        # The contextual bias attention and its fusion conv live in the last
        # attention layer of the TF graph (decoder_fsmn_layer_15 for the
        # 16-attention-layer checkpoints this mapping was written for).
        bias_layer_tf = "{}/decoder_fsmn_layer_{}".format(
            prefix_tf, self.att_layer_num - 1
        )

        def entry(name, squeeze=None, transpose=None):
            return {"name": name, "squeeze": squeeze, "transpose": transpose}

        # (torch suffix, TF suffix, squeeze axis, transpose axes) for
        # sub-blocks that repeat across the decoder. TF conv kernels carry a
        # leading kernel-width-1 axis that is squeezed away before
        # transposing, e.g. torch (1024, 256) <- TF (1, 256, 1024).
        ffn_params = (
            ("norm1.weight", "LayerNorm/gamma", None, None),
            ("norm1.bias", "LayerNorm/beta", None, None),
            ("feed_forward.w_1.weight", "conv1d/kernel", 0, (1, 0)),
            ("feed_forward.w_1.bias", "conv1d/bias", None, None),
            ("feed_forward.norm.weight", "LayerNorm_1/gamma", None, None),
            ("feed_forward.norm.bias", "LayerNorm_1/beta", None, None),
            ("feed_forward.w_2.weight", "conv1d_1/kernel", 0, (1, 0)),
        )
        # FSMN memory block; e.g. torch (256, 1, 31) <- TF (1, 31, 256, 1).
        fsmn_params = (
            ("norm2.weight", "LayerNorm/gamma", None, None),
            ("norm2.bias", "LayerNorm/beta", None, None),
            ("self_attn.fsmn_block.weight", "depth_conv_w", 0, (1, 2, 0)),
        )
        # Cross-attention (query, fused key/value, output projections).
        attn_params = (
            ("norm3.weight", "LayerNorm/gamma", None, None),
            ("norm3.bias", "LayerNorm/beta", None, None),
            ("src_attn.linear_q.weight", "conv1d/kernel", 0, (1, 0)),
            ("src_attn.linear_q.bias", "conv1d/bias", None, None),
            ("src_attn.linear_k_v.weight", "conv1d_1/kernel", 0, (1, 0)),
            ("src_attn.linear_k_v.bias", "conv1d_1/bias", None, None),
            ("src_attn.linear_out.weight", "conv1d_2/kernel", 0, (1, 0)),
            ("src_attn.linear_out.bias", "conv1d_2/bias", None, None),
        )
        layer_torch = "{}.decoders.layeridx.".format(prefix_torch)
        layer_tf = "{}/decoder_fsmn_layer_layeridx/".format(prefix_tf)
        groups = (
            (layer_torch, layer_tf + "decoder_ffn/", ffn_params),
            (layer_torch, layer_tf + "decoder_memory_block/", fsmn_params),
            (layer_torch, layer_tf + "multi_head/", attn_params),
            (
                "{}.decoders3.layeridx.".format(prefix_torch),
                "{}/decoder_dnn_layer_layeridx/".format(prefix_tf),
                ffn_params,
            ),
            (
                "{}.embed_concat_ffn.layeridx.".format(prefix_torch),
                "{}/cif_concat/".format(prefix_tf),
                ffn_params,
            ),
            (
                "{}.bias_decoder.".format(prefix_torch),
                bias_layer_tf + "/multi_head_1/",
                attn_params,
            ),
        )
        map_dict_local = {
            torch_prefix + torch_suffix: entry(tf_prefix + tf_suffix, squeeze, transpose)
            for torch_prefix, tf_prefix, params in groups
            for torch_suffix, tf_suffix, squeeze, transpose in params
        }
        map_dict_local.update(
            {
                # out norm
                "{}.after_norm.weight".format(prefix_torch): entry(
                    "{}/LayerNorm/gamma".format(prefix_tf)
                ),
                "{}.after_norm.bias".format(prefix_torch): entry(
                    "{}/LayerNorm/beta".format(prefix_tf)
                ),
                # in embed
                "{}.embed.0.weight".format(prefix_torch): entry(
                    "{}/w_embs".format(prefix_tf)
                ),
                # out layer: dense kernel if present, else the embedding matrix
                "{}.output_layer.weight".format(prefix_torch): entry(
                    [
                        "{}/dense/kernel".format(prefix_tf),
                        "{}/w_embs".format(prefix_tf),
                    ],
                    [None, None],
                    [(1, 0), None],
                ),
                "{}.output_layer.bias".format(prefix_torch): entry(
                    [
                        "{}/dense/bias".format(prefix_tf),
                        (
                            "seq2seq/2bias"
                            if prefix_tf == "seq2seq/decoder/inputter_1"
                            else "seq2seq/bias"
                        ),
                    ],
                    [None, None],
                    [None, None],
                ),
                # bias fusion conv: torch Conv1d weight is
                # (attention_dim, 2 * attention_dim, 1); the TF kernel is
                # reversed with (2, 1, 0), so presumably (1, 2 * D, D).
                "{}.bias_output.weight".format(prefix_torch): entry(
                    "{}/conv1d/kernel".format(bias_layer_tf), None, (2, 1, 0)
                ),
            }
        )
        return map_dict_local

    @staticmethod
    def _tf_to_torch_tensor(
        name, name_tf, squeeze, transpose, var_dict_tf, var_dict_torch
    ):
        """Reshape TF array ``var_dict_tf[name_tf]`` into a float32 CPU tensor.

        Applies the optional squeeze/transpose, asserts the result matches
        ``var_dict_torch[name]`` in size, logs the mapping and returns it.
        """
        data_tf = var_dict_tf[name_tf]
        if squeeze is not None:
            data_tf = np.squeeze(data_tf, axis=squeeze)
        if transpose is not None:
            data_tf = np.transpose(data_tf, transpose)
        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
            name, name_tf, var_dict_torch[name].size(), data_tf.size()
        )
        logging.info(
            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
            )
        )
        return data_tf

    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):
        """Convert TF checkpoint arrays into torch tensors for this decoder.

        Args:
            var_dict_tf: mapping TF variable name -> numpy array.
            var_dict_torch: mapping torch parameter name -> tensor (used for
                names and shape checks only).
        Returns:
            dict of torch parameter name -> converted float32 CPU tensor.
            Names under the torch prefix that have no mapping are skipped for
            layered and ``bias_decoder`` modules. For ``embed``,
            ``output_layer``, ``bias_output`` and ``after_norm`` an unmapped
            name raises KeyError.

        TF layer indices: ``decoders[i]`` -> i, ``last_decoder`` ->
        ``att_layer_num - 1``, ``decoders2[i]`` -> ``att_layer_num + i``.
        Previously ``decoders2`` was offset by only ``att_layer_num - 1``,
        because ``last_decoder`` sorts after ``decoders2``. That made
        ``decoders2[0]`` collide with ``last_decoder``'s TF layer.
        """
        map_dict = self.gen_tf2torch_map_dict()
        last_layer_idx = self.att_layer_num - 1
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys()):
            names = name.split(".")
            if names[0] != self.tf2torch_tensor_name_prefix_torch:
                continue
            module = names[1]
            if module in (
                "decoders",
                "decoders2",
                "decoders3",
                "embed_concat_ffn",
                "last_decoder",
            ):
                if module == "last_decoder":
                    layeridx = last_layer_idx
                    name_q = name.replace("last_decoder", "decoders.layeridx")
                else:
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if module == "decoders2":
                        # decoders2 shares the fsmn-layer TF naming and
                        # continues after all attention layers.
                        name_q = name_q.replace("decoders2", "decoders")
                        layeridx += self.att_layer_num
                if name_q not in map_dict:
                    continue
                spec = map_dict[name_q]
                name_tf = spec["name"].replace("layeridx", str(layeridx))
                squeeze, transpose = spec["squeeze"], spec["transpose"]
            elif module == "bias_decoder":
                if name not in map_dict:
                    continue
                spec = map_dict[name]
                name_tf = spec["name"]
                squeeze, transpose = spec["squeeze"], spec["transpose"]
            elif module in ("embed", "output_layer", "bias_output", "after_norm"):
                spec = map_dict[name]
                name_tf = spec["name"]
                squeeze, transpose = spec["squeeze"], spec["transpose"]
                if isinstance(name_tf, list):
                    # Use the first candidate that exists in the checkpoint.
                    idx = 0 if name_tf[0] in var_dict_tf else 1
                    name_tf = name_tf[idx]
                    squeeze, transpose = squeeze[idx], transpose[idx]
            else:
                continue
            var_dict_torch_update[name] = self._tf_to_torch_tensor(
                name, name_tf, squeeze, transpose, var_dict_tf, var_dict_torch
            )
        return var_dict_torch_update