# Scraper artifacts from the original source page (a "Spaces:" header and two
# "Runtime error" status lines) — not part of the code.
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
def tile(x, count, dim=0):
    """Repeat each slice of ``x`` along ``dim`` ``count`` times, consecutively.

    Classic OpenNMT-style tiling: every index along ``dim`` is duplicated
    ``count`` times in a row, so a size-``d`` dimension becomes ``d * count``
    (e.g. rows ``[a, b]`` with ``count=2`` become ``[a, a, b, b]``).

    Args:
        x (torch.Tensor): tensor to tile.
        count (int): number of consecutive repetitions per slice.
        dim (int): dimension along which to tile (default 0).

    Returns:
        torch.Tensor: tiled tensor; all other dimensions are unchanged.
    """
    # The original permute/view/repeat/transpose dance computes exactly what
    # torch.repeat_interleave does in one fused, C-level call.
    return torch.repeat_interleave(x, count, dim=dim)
class Linear(torch.nn.Module):
    """Fully-connected layer with Xavier-uniform weight initialization.

    Thin wrapper around ``torch.nn.Linear`` whose weight matrix is
    initialized with ``xavier_uniform_`` using the gain associated with
    ``w_init_gain`` (a nonlinearity name accepted by ``calculate_gain``).
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        # Delegate straight to the wrapped linear layer.
        return self.linear_layer(x)
class Conv1d(torch.nn.Module):
    """1-D convolution with Xavier-uniform init and optional "same" padding.

    When ``padding`` is None, ``kernel_size`` must be odd and padding is
    chosen as ``dilation * (kernel_size - 1) // 2`` so that, at stride 1,
    the output length equals the input length.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): convolution kernel width (default 1).
        stride (int): convolution stride (default 1).
        padding (int | None): explicit padding; None selects "same" padding.
        dilation (int): kernel dilation (default 1).
        bias (bool): whether the convolution has a bias term (default True).
        w_init_gain (str): nonlinearity name for ``calculate_gain``.
        param (float | None): optional ``calculate_gain`` parameter
            (e.g. the negative slope for 'leaky_relu').

    Raises:
        ValueError: if ``padding`` is None and ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(Conv1d, self).__init__()
        if padding is None:
            # An even kernel cannot be padded symmetrically to preserve length.
            # Raise instead of assert: asserts are stripped under ``python -O``.
            if kernel_size % 2 != 1:
                raise ValueError(
                    'kernel_size must be odd when padding is None, '
                    'got {}'.format(kernel_size))
            # Floor division keeps the result an int without a float round-trip.
            padding = dilation * (kernel_size - 1) // 2
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        # x: (B, D, T) — batch, channels, time.
        return self.conv(x)
def tile(x, count, dim=0):
    """Repeat each slice of ``x`` along ``dim`` ``count`` times, consecutively.

    A size-``d`` dimension becomes ``d * count``; e.g. rows ``[a, b]`` with
    ``count=2`` become ``[a, a, b, b]``.

    NOTE(review): this is a verbatim duplicate of the ``tile`` defined earlier
    in this file; consider removing one copy.
    """
    ndim = x.dim()
    # Bring `dim` to the front so the repeat can operate on dimension 0.
    order = list(range(ndim))
    if dim != 0:
        order[0], order[dim] = order[dim], order[0]
        x = x.permute(order).contiguous()
    target_shape = list(x.size())
    target_shape[0] *= count
    n = x.size(0)
    flat = x.view(n, -1)
    # Repeating the transposed view along dim 0 interleaves the rows so that,
    # after reshaping, each original row appears `count` times in a row.
    tiled = flat.transpose(0, 1).repeat(count, 1).transpose(0, 1)
    x = tiled.contiguous().view(*target_shape)
    # Swapping the same two axes again restores the original layout.
    if dim != 0:
        x = x.permute(order).contiguous()
    return x