from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class FsmnLayer(nn.Module):
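    """Single FSMN (Feedforward Sequential Memory Network) layer.

    The input is projected with a 1x1 convolution, left/right temporal context is
    mixed in with depthwise 1-D convolutions over ``left_frame``/``right_frame``
    neighboring frames (optionally dilated), and the result is projected to
    ``out_dim`` through a ReLU. ``forward`` processes whole sequences; ``infer``
    processes streaming chunks using an externally managed flat cache tensor.
    """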
def __init__(
self,
input_dim,
out_dim,
hidden_dim,
left_frame=1,
right_frame=1,
left_dilation=1,
right_dilation=1,
):
super(FsmnLayer, self).__init__()
self.input_dim = input_dim
self.out_dim = out_dim
self.hidden_dim = hidden_dim
self.left_frame = left_frame
self.right_frame = right_frame
self.left_dilation = left_dilation
self.right_dilation = right_dilation
self.conv_in = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
if left_frame > 0:
self.pad_left = nn.ConstantPad1d([left_dilation * left_frame, 0], 0.0)
self.conv_left = nn.Conv1d(
hidden_dim,
hidden_dim,
kernel_size=left_frame + 1,
dilation=left_dilation,
bias=False,
groups=hidden_dim,
)
        if right_frame > 0:
            # The negative left padding trims `right_dilation` frames, so the look-ahead
            # conv covers frames t + right_dilation .. t + right_frame * right_dilation
            # (future context only); the right padding preserves the sequence length.
            self.pad_right = nn.ConstantPad1d([-right_dilation, right_dilation * right_frame], 0.0)
self.conv_right = nn.Conv1d(
hidden_dim,
hidden_dim,
kernel_size=right_frame,
dilation=right_dilation,
bias=False,
groups=hidden_dim,
)
self.conv_out = nn.Conv1d(hidden_dim, out_dim, kernel_size=1)
        # streaming cache shape: [1, hidden_dim, left_frame * left_dilation + right_frame * right_dilation]
        self.cache_size = left_frame * left_dilation + right_frame * right_dilation
        self.buffer_size = self.hidden_dim * self.cache_size
        self.p_in_raw_cache_size = self.right_frame * self.right_dilation
        self.p_in_raw_buffer_size = self.hidden_dim * self.p_in_raw_cache_size
        self.hidden_cache_size = self.right_frame * self.right_dilation
        self.hidden_buffer_size = self.hidden_dim * self.hidden_cache_size
@torch.jit.unused
def forward(self, x, hidden=None):
x_data = x.transpose(1, 2)
p_in = self.conv_in(x_data)
if self.left_frame > 0:
p_left = self.pad_left(p_in)
p_left = self.conv_left(p_left)
else:
p_left = 0
if self.right_frame > 0:
p_right = self.pad_right(p_in)
p_right = self.conv_right(p_right)
else:
p_right = 0
p_out = p_in + p_right + p_left
if hidden is not None:
p_out = hidden + p_out
out = F.relu(self.conv_out(p_out))
out = out.transpose(1, 2)
return out, p_out
@torch.jit.export
def infer(self, x, buffer, buffer_index, buffer_out, hidden=None):
        # type: (Tensor, Tensor, Tensor, List[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor], Tensor]
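        # Streaming protocol: `x` is expected channel-first, [batch, input_dim, chunk_frames]
        # (no transpose here, unlike forward). `buffer` is a flat cache shared across calls;
        # each slice read below is context carried over from the previous chunk, the updated
        # slices are appended to `buffer_out`, and `buffer_index` is advanced so the caller
        # can rebuild the cache for the next chunk.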
p_in_raw = self.conv_in(x)
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
[1, self.hidden_dim, self.cache_size]
)
p_in = torch.cat([cnn_buffer, p_in_raw], dim=2)
# buffer[buffer_index: buffer_index + self.buffer_size] = p_in[:, :, -self.cache_size:].reshape(-1)
buffer_out.append(p_in[:, :, -self.cache_size :].reshape(-1))
buffer_index = buffer_index + self.buffer_size
if self.left_frame > 0:
if self.right_frame > 0:
p_left = p_in[:, :, : -self.right_frame * self.right_dilation]
else:
p_left = p_in[:, :]
p_left_out = self.conv_left(p_left)
else:
p_left_out = torch.tensor([0])
if self.right_frame > 0:
p_right = p_in[:, :, self.left_frame * self.left_dilation + 1 :]
p_right_out = self.conv_right(p_right)
else:
p_right_out = torch.tensor([0])
if self.right_frame > 0:
            p_in_raw_cnn_buffer = buffer[
                buffer_index : buffer_index + self.p_in_raw_buffer_size
            ].reshape([1, self.hidden_dim, self.p_in_raw_cache_size])
            p_in_raw = torch.cat([p_in_raw_cnn_buffer, p_in_raw], dim=2)
            # buffer[buffer_index: buffer_index + self.p_in_raw_buffer_size] = p_in_raw[:, :, -self.p_in_raw_cache_size:].reshape(-1)
            buffer_out.append(p_in_raw[:, :, -self.p_in_raw_cache_size :].reshape(-1))
            buffer_index = buffer_index + self.p_in_raw_buffer_size
            p_in_raw = p_in_raw[:, :, : -self.p_in_raw_cache_size]
p_out = p_in_raw + p_left_out + p_right_out
if hidden is not None:
if self.right_frame > 0:
                hidden_cnn_buffer = buffer[
                    buffer_index : buffer_index + self.hidden_buffer_size
                ].reshape([1, self.hidden_dim, self.hidden_cache_size])
                hidden = torch.cat([hidden_cnn_buffer, hidden], dim=2)
                # buffer[buffer_index: buffer_index + self.hidden_buffer_size] = hidden[:, :, -self.hidden_cache_size:].reshape(-1)
                buffer_out.append(hidden[:, :, -self.hidden_cache_size :].reshape(-1))
                buffer_index = buffer_index + self.hidden_buffer_size
                hidden = hidden[:, :, : -self.hidden_cache_size]
p_out = hidden + p_out
out = F.relu(self.conv_out(p_out))
return out, buffer, buffer_index, buffer_out, p_out
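    # Minimal usage sketch (illustrative only; the dimensions below are assumptions,
    # not values taken from this app):
    #     layer1 = FsmnLayer(input_dim=80, out_dim=80, hidden_dim=128,
    #                        left_frame=3, right_frame=1)
    #     layer2 = FsmnLayer(input_dim=80, out_dim=80, hidden_dim=128,
    #                        left_frame=3, right_frame=1)
    #     x = torch.randn(1, 200, 80)            # [batch, frames, input_dim]
    #     out1, memory = layer1(x)               # offline forward; out1 is [1, 200, 80]
    #     out2, _ = layer2(out1, hidden=memory)  # stacked layers could chain the memory output
    # For streaming, call infer() with chunk inputs shaped [1, input_dim, chunk] and an
    # externally managed flat cache (see the comments inside infer).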