File size: 1,438 Bytes
0102e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from torch.nn import functional as F
from funasr_detach.models.decoder.abs_decoder import AbsDecoder


class DenseDecoder(AbsDecoder):
    """Feed-forward decoder: two dense stages followed by a bias-free projection.

    Each stage applies ``Linear -> ReLU -> BatchNorm1d``. The final layer maps
    ``num_nodes_last_layer`` features to ``vocab_size`` outputs; no softmax or
    other normalization is applied inside this module.

    Despite the ``resnet*`` attribute names, ``forward`` adds no skip
    connections — the stages are applied strictly one after another.

    Args:
        vocab_size: Number of output units of the final projection.
        encoder_output_size: Feature size of the incoming ``features``.
        num_nodes_resnet1: Width of the first dense stage.
        num_nodes_last_layer: Width of the second dense stage.
        batchnorm_momentum: ``momentum`` passed to both ``BatchNorm1d`` layers
            (PyTorch's convention for running-statistics updates).
    """

    def __init__(
        self,
        vocab_size,
        encoder_output_size,
        num_nodes_resnet1: int = 256,
        num_nodes_last_layer: int = 256,
        batchnorm_momentum: float = 0.5,
    ):
        super().__init__()
        # Both normalization layers share the same eps / momentum settings.
        bn_opts = {"eps": 1e-3, "momentum": batchnorm_momentum}

        # Attribute names become the parameter/state_dict key prefixes; keep
        # them unchanged so existing checkpoints still load.
        self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, **bn_opts)

        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, **bn_opts)

        self.output_dense = torch.nn.Linear(
            num_nodes_last_layer, vocab_size, bias=False
        )

    def forward(self, features):
        """Run both dense stages and the output projection.

        Args:
            features: Input tensor. Presumably ``(batch, encoder_output_size)``,
                since ``Linear`` acts on the last dimension while
                ``BatchNorm1d`` normalizes dimension 1 — TODO confirm against
                callers.

        Returns:
            Tuple ``(outputs, embeddings)``. ``outputs`` is the bias-free
            projection to ``vocab_size``. ``embeddings`` maps
            ``"resnet1_dense"`` and ``"resnet2_dense"`` to each stage's dense
            output, captured *before* ReLU and BatchNorm.
        """
        stages = (
            ("resnet1_dense", self.resnet1_dense, self.resnet1_bn),
            ("resnet2_dense", self.resnet2_dense, self.resnet2_bn),
        )
        embeddings = {}
        hidden = features
        for key, dense, norm in stages:
            hidden = dense(hidden)
            # Record the pre-activation output of this stage's dense layer.
            embeddings[key] = hidden
            hidden = norm(F.relu(hidden))
        return self.output_dense(hidden), embeddings