Spaces:
Runtime error
Runtime error
File size: 1,438 Bytes
0102e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch
from torch.nn import functional as F
from funasr_detach.models.decoder.abs_decoder import AbsDecoder
class DenseDecoder(AbsDecoder):
    """Two-stage fully connected decoder head.

    Projects encoder outputs through two Linear -> ReLU -> BatchNorm stages,
    then a final bias-free linear layer producing vocabulary logits.  In
    addition to the logits, ``forward`` returns the pre-activation outputs
    of both hidden dense layers, keyed by layer name, for downstream use.
    """

    def __init__(
        self,
        vocab_size,
        encoder_output_size,
        num_nodes_resnet1: int = 256,
        num_nodes_last_layer: int = 256,
        batchnorm_momentum: float = 0.5,
    ):
        super().__init__()
        # Stage 1: encoder_output_size -> num_nodes_resnet1.
        self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(
            num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum
        )
        # Stage 2: num_nodes_resnet1 -> num_nodes_last_layer.
        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(
            num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum
        )
        # Final projection to vocabulary logits; bias omitted by design.
        self.output_dense = torch.nn.Linear(
            num_nodes_last_layer, vocab_size, bias=False
        )

    def forward(self, features):
        """Compute logits and hidden-layer embeddings.

        Args:
            features: encoder output tensor whose last dimension equals
                ``encoder_output_size``.  NOTE(review): BatchNorm1d expects
                channels in dim 1, so this presumably receives 2-D
                ``(batch, dim)`` input — confirm against the caller.

        Returns:
            Tuple of ``(logits, embeddings)`` where ``embeddings`` maps each
            hidden dense layer's name to its pre-activation output.
        """
        embeddings = {}

        hidden = self.resnet1_dense(features)
        embeddings["resnet1_dense"] = hidden
        hidden = self.resnet1_bn(F.relu(hidden))

        hidden = self.resnet2_dense(hidden)
        embeddings["resnet2_dense"] = hidden
        hidden = self.resnet2_bn(F.relu(hidden))

        return self.output_dense(hidden), embeddings
|