from typing import List, Tuple

import torch
import torch.nn as nn
from torch import Tensor
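
# DTCBlock: a dilated temporal convolution block. The path is depthwise
# Conv1d -> BN -> 1x1 conv -> BN -> ReLU -> 1x1 conv -> BN, with a residual
# connection (when stride == 1), a final ReLU, and dropout. `forward` runs
# offline over whole sequences; `infer` runs chunk by chunk for streaming.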
class DTCBlock(nn.Module):
    def __init__(
        self, input_dim, output_dim, kernel_size, stride, causal_conv, dilation, dropout_rate
    ):
        super().__init__()
        self.input_dim = input_dim
        # output_dim is stored for reference; every conv below keeps input_dim channels.
        self.output_dim = output_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        if causal_conv:
            # Causal: pad lorder zeros on the left only, so no future frames leak in.
            self.padding = 0
            self.lorder = (kernel_size - 1) * self.dilation
            self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
        else:
            # Symmetric padding requires an odd kernel size.
            assert (kernel_size - 1) % 2 == 0
            self.padding = ((kernel_size - 1) // 2) * self.dilation
            self.lorder = 0
        self.causal_conv = causal_conv
        self.depthwise_conv = nn.Conv1d(
            self.input_dim,
            self.input_dim,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            groups=self.input_dim,
        )
        # Pointwise (kernel size 1) convolutions need no padding.
        self.point_conv_1 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, 0)
        self.point_conv_2 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, 0)
        self.bn_1 = nn.BatchNorm1d(self.input_dim)
        self.bn_2 = nn.BatchNorm1d(self.input_dim)
        self.bn_3 = nn.BatchNorm1d(self.input_dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Streaming context: `infer` caches the last lorder frames per channel
        # between chunks, stored flat with shape (1, input_dim, lorder).
        self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
        self.buffer_size = 1 * self.input_dim * self.lorder
    def forward(self, x):
        # x: (batch, time, input_dim)
        x_in = x
        x_data = x_in.transpose(1, 2)  # -> (batch, input_dim, time)
        if self.causal_conv:
            x_data_pad = self.left_padding(x_data)
        else:
            x_data_pad = x_data
        x_depth = self.depthwise_conv(x_data_pad)
        x_bn_1 = self.bn_1(x_depth)
        x_point_1 = self.point_conv_1(x_bn_1)
        x_bn_2 = self.bn_2(x_point_1)
        x_relu_2 = torch.relu(x_bn_2)
        x_point_2 = self.point_conv_2(x_relu_2)
        x_bn_3 = self.bn_3(x_point_2)
        x_bn_3 = x_bn_3.transpose(1, 2)  # back to (batch, time, input_dim)
        if self.stride == 1:
            # Residual connection only when the time resolution is unchanged.
            x_relu_3 = torch.relu(x_bn_3 + x_in)
        else:
            x_relu_3 = torch.relu(x_bn_3)
        x_drop = self.dropout(x_relu_3)
        return x_drop
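    # Streaming inference: the left context that `forward` obtains by padding
    # is instead read from, and written back to, a flat external buffer, so a
    # stack of blocks can share one buffer and run chunk by chunk.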
    def infer(self, x, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
        x_in = x
        x = x_in.transpose(1, 2)
        # Read this block's cached left context from the flat buffer and
        # prepend it to the incoming chunk.
        cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
            [1, self.input_dim, self.lorder]
        )
        x = torch.cat([cnn_buffer, x], dim=2)
        # Save the trailing lorder frames as the context for the next chunk.
        buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
        buffer_index = buffer_index + self.buffer_size
        x = self.depthwise_conv(x)
        x = self.bn_1(x)
        x = self.point_conv_1(x)
        x = self.bn_2(x)
        x = torch.relu(x)
        x = self.point_conv_2(x)
        x = self.bn_3(x)
        x = x.transpose(1, 2)
        if self.stride == 1:
            x = torch.relu(x + x_in)
        else:
            x = torch.relu(x)
        return x, buffer, buffer_index, buffer_out
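
# A minimal smoke test of both paths. This is an illustrative sketch, not part
# of the original source: the hyperparameters below (input_dim=64, kernel_size=8,
# stride=1, dilation=1, dropout_rate=0.1) are assumed values for demonstration.
if __name__ == "__main__":
    block = DTCBlock(
        input_dim=64,
        output_dim=64,
        kernel_size=8,
        stride=1,
        causal_conv=True,
        dilation=1,
        dropout_rate=0.1,
    )
    block.eval()

    # Offline forward: (batch, time, input_dim) in and out.
    x = torch.randn(1, 100, 64)
    with torch.no_grad():
        y = block(x)
    print(y.shape)  # torch.Size([1, 100, 64])

    # Streaming inference on one chunk, starting from a zero context buffer.
    # lorder = (8 - 1) * 1 - 0 = 7, so buffer_size = 64 * 7 = 448.
    buffer = torch.zeros(block.buffer_size)
    buffer_index = torch.tensor(0)
    buffer_out: List[Tensor] = []
    chunk = torch.randn(1, 10, 64)
    with torch.no_grad():
        y_chunk, buffer, buffer_index, buffer_out = block.infer(
            chunk, buffer, buffer_index, buffer_out
        )
    print(y_chunk.shape)  # torch.Size([1, 10, 64])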