import torch
from torch.nn import functional as F

from funasr_detach.models.decoder.abs_decoder import AbsDecoder


class DenseDecoder(AbsDecoder):
    """Feed-forward decoder: two Linear -> ReLU -> BatchNorm1d stages, then a
    bias-free Linear projection to ``vocab_size`` outputs.

    Despite the ``resnet*`` attribute names, this block applies no skip
    connection; the stages are purely sequential.

    Args:
        vocab_size: Output width of the final linear layer.
        encoder_output_size: Input width of the first linear layer.
        num_nodes_resnet1: Width of the first hidden stage.
        num_nodes_last_layer: Width of the second hidden stage.
        batchnorm_momentum: ``momentum`` passed to both BatchNorm1d layers.
    """

    def __init__(
        self,
        vocab_size,
        encoder_output_size,
        num_nodes_resnet1: int = 256,
        num_nodes_last_layer: int = 256,
        batchnorm_momentum: float = 0.5,
    ):
        super().__init__()

        # Both normalisation layers share the same eps / momentum settings.
        bn_kwargs = {"eps": 1e-3, "momentum": batchnorm_momentum}

        # Stage 1: encoder_output_size -> num_nodes_resnet1
        self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, **bn_kwargs)

        # Stage 2: num_nodes_resnet1 -> num_nodes_last_layer
        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, **bn_kwargs)

        # Final projection carries no bias term.
        self.output_dense = torch.nn.Linear(
            num_nodes_last_layer, vocab_size, bias=False
        )

    def forward(self, features):
        """Run the dense stack and expose the intermediate dense outputs.

        Args:
            features: Input tensor whose last dimension is
                ``encoder_output_size``.
                NOTE(review): BatchNorm1d is built with num_features equal to
                the hidden width, so this presumably expects 2-D
                ``(batch, encoder_output_size)`` input -- TODO confirm against
                callers.

        Returns:
            A tuple ``(output, embeddings)`` where ``output`` is the result of
            ``output_dense`` and ``embeddings`` maps ``"resnet1_dense"`` and
            ``"resnet2_dense"`` to the outputs of the corresponding linear
            layers, taken before ReLU and batch normalisation.
        """
        stage1_linear = self.resnet1_dense(features)
        stage1_out = self.resnet1_bn(F.relu(stage1_linear))

        stage2_linear = self.resnet2_dense(stage1_out)
        stage2_out = self.resnet2_bn(F.relu(stage2_linear))

        output = self.output_dense(stage2_out)

        # F.relu is not in-place here, so the stored tensors remain the raw
        # linear outputs.
        embeddings = {
            "resnet1_dense": stage1_linear,
            "resnet2_dense": stage2_linear,
        }
        return output, embeddings