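# FSMN (Feedforward Sequential Memory Network) layer: a 1x1 input projection followed by
# depthwise "memory" convolutions over left (past) and right (lookahead) context, with a
# flat-cache variant (infer) for chunk-by-chunk streaming inference.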
import torch
import torch.nn as nn
import torch.nn.functional as F


class FsmnLayer(nn.Module):
    def __init__(
        self,
        input_dim,
        out_dim,
        hidden_dim,
        left_frame=1,
        right_frame=1,
        left_dilation=1,
        right_dilation=1,
    ):
        super(FsmnLayer, self).__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.left_frame = left_frame
        self.right_frame = right_frame
        self.left_dilation = left_dilation
        self.right_dilation = right_dilation
        self.conv_in = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
        if left_frame > 0:
            # causal (left) memory block: depthwise conv over the current frame plus left_frame past frames
            self.pad_left = nn.ConstantPad1d([left_dilation * left_frame, 0], 0.0)
            self.conv_left = nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=left_frame + 1,
                dilation=left_dilation,
                bias=False,
                groups=hidden_dim,
            )
        if right_frame > 0:
            # anti-causal (right) memory block: the negative left pad shifts the window one dilation step into the future
            self.pad_right = nn.ConstantPad1d([-right_dilation, right_dilation * right_frame], 0.0)
            self.conv_right = nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=right_frame,
                dilation=right_dilation,
                bias=False,
                groups=hidden_dim,
            )
        self.conv_out = nn.Conv1d(hidden_dim, out_dim, kernel_size=1)
        # streaming cache bookkeeping; p_in cache shape: (1, hidden_dim, left_frame * left_dilation + right_frame * right_dilation)
        self.cache_size = left_frame * left_dilation + right_frame * right_dilation
        self.buffer_size = self.hidden_dim * self.cache_size
        self.p_in_raw_cache_size = self.right_frame * self.right_dilation
        self.p_in_raw_buffer_size = self.hidden_dim * self.p_in_raw_cache_size
        self.hidden_cache_size = self.right_frame * self.right_dilation
        self.hidden_buffer_size = self.hidden_dim * self.hidden_cache_size
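        # The flat `buffer` consumed by infer() packs these caches back to back, in order:
        #   buffer_size          = hidden_dim * cache_size           -> p_in history (left + right context)
        #   p_in_raw_buffer_size = hidden_dim * p_in_raw_cache_size  -> delayed p_in_raw (only if right_frame > 0)
        #   hidden_buffer_size   = hidden_dim * hidden_cache_size    -> delayed skip input (only if right_frame > 0 and hidden is given)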

    def forward(self, x, hidden=None):
        x_data = x.transpose(1, 2)  # (B, T, D) -> (B, D, T)
        p_in = self.conv_in(x_data)
        if self.left_frame > 0:
            p_left = self.pad_left(p_in)
            p_left = self.conv_left(p_left)
        else:
            p_left = 0
        if self.right_frame > 0:
            p_right = self.pad_right(p_in)
            p_right = self.conv_right(p_right)
        else:
            p_right = 0
        p_out = p_in + p_right + p_left
        if hidden is not None:
            p_out = hidden + p_out  # skip connection from the previous layer's memory output
        out = F.relu(self.conv_out(p_out))
        out = out.transpose(1, 2)  # (B, D, T) -> (B, T, D)
        return out, p_out
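
    # Streaming variant of forward(): reads cached history from the flat `buffer` starting at
    # `buffer_index`, appends the updated cache slices to `buffer_out` in the same order, and
    # (when right_frame > 0) emits output delayed by right_frame * right_dilation frames to
    # account for the lookahead context.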
    def infer(self, x, buffer, buffer_index, buffer_out, hidden=None):
        # type: (Tensor, Tensor, Tensor, List[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor], Tensor]
        p_in_raw = self.conv_in(x)
        cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
            [1, self.hidden_dim, self.cache_size]
        )
        p_in = torch.cat([cnn_buffer, p_in_raw], dim=2)
        # buffer[buffer_index: buffer_index + self.buffer_size] = p_in[:, :, -self.cache_size:].reshape(-1)
        buffer_out.append(p_in[:, :, -self.cache_size :].reshape(-1))
        buffer_index = buffer_index + self.buffer_size
        if self.left_frame > 0:
            if self.right_frame > 0:
                p_left = p_in[:, :, : -self.right_frame * self.right_dilation]
            else:
                p_left = p_in[:, :]
            p_left_out = self.conv_left(p_left)
        else:
            p_left_out = torch.tensor([0])
        if self.right_frame > 0:
            # skip the left context plus one right-dilation step so conv_right sees only lookahead frames
            p_right = p_in[:, :, self.left_frame * self.left_dilation + self.right_dilation :]
            p_right_out = self.conv_right(p_right)
        else:
            p_right_out = torch.tensor([0])
        if self.right_frame > 0:
            p_in_raw_cnn_buffer = buffer[
                buffer_index : buffer_index + self.p_in_raw_buffer_size
            ].reshape([1, self.hidden_dim, self.p_in_raw_cache_size])
            p_in_raw = torch.cat([p_in_raw_cnn_buffer, p_in_raw], dim=2)
            # buffer[buffer_index: buffer_index + self.p_in_raw_buffer_size] = p_in_raw[:, :, -self.p_in_raw_cache_size:].reshape(-1)
            buffer_out.append(p_in_raw[:, :, -self.p_in_raw_cache_size :].reshape(-1))
            buffer_index = buffer_index + self.p_in_raw_buffer_size
            # drop the newest right_frame * right_dilation frames so p_in_raw lines up with the delayed output
            p_in_raw = p_in_raw[:, :, : -self.p_in_raw_cache_size]
        p_out = p_in_raw + p_left_out + p_right_out
        if hidden is not None:
            if self.right_frame > 0:
                hidden_cnn_buffer = buffer[
                    buffer_index : buffer_index + self.hidden_buffer_size
                ].reshape([1, self.hidden_dim, self.hidden_cache_size])
                hidden = torch.cat([hidden_cnn_buffer, hidden], dim=2)
                # buffer[buffer_index: buffer_index + self.hidden_buffer_size] = hidden[:, :, -self.hidden_cache_size:].reshape(-1)
                buffer_out.append(hidden[:, :, -self.hidden_cache_size :].reshape(-1))
                buffer_index = buffer_index + self.hidden_buffer_size
                hidden = hidden[:, :, : -self.hidden_cache_size]
            p_out = hidden + p_out
        out = F.relu(self.conv_out(p_out))
        return out, buffer, buffer_index, buffer_out, p_out
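

# A minimal usage sketch (not part of the original layer; the dimensions and chunking below
# are assumptions for illustration): forward() consumes a whole (batch, time, input_dim)
# utterance, while infer() consumes (1, input_dim, chunk) plus the flat cache buffer, which
# the caller rebuilds from `buffer_out` after every chunk.
if __name__ == "__main__":
    layer = FsmnLayer(input_dim=40, out_dim=40, hidden_dim=64, left_frame=5, right_frame=2)

    # Offline: one full utterance at once.
    x = torch.randn(1, 100, 40)
    out, p_out = layer(x)  # out: (1, 100, 40), p_out: (1, 64, 100)

    # Streaming: with hidden=None only the p_in and p_in_raw caches are read, so the buffer
    # needs buffer_size + p_in_raw_buffer_size zeros (add hidden_buffer_size if a skip input is used).
    buffer = torch.zeros(layer.buffer_size + layer.p_in_raw_buffer_size)
    for _ in range(5):
        chunk = torch.randn(1, 40, 10)  # (1, input_dim, chunk_frames)
        buffer_out = []
        out, buffer, _, buffer_out, p_out = layer.infer(
            chunk, buffer, torch.tensor(0), buffer_out, hidden=None
        )
        buffer = torch.cat(buffer_out)  # cache carried into the next chunk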