import torch
import torch.nn as nn


class nonlinearity(nn.Module):
    """Swish/SiLU activation: x * sigmoid(x) (equivalent to nn.SiLU)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class ResConv1DBlock(nn.Module):
    """Residual block: (norm -> activation -> dilated 3x1 conv) followed by
    (norm -> activation -> 1x1 conv), with a skip connection around both."""

    def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
        super().__init__()
        # `dropout` is accepted for interface compatibility but not used by this block.
        padding = dilation  # keeps the sequence length unchanged for kernel size 3
        self.norm = norm

        if norm == "LN":
            # LayerNorm normalizes over the last dimension, so the forward pass
            # transposes to (N, T, C) before applying it.
            self.norm1 = nn.LayerNorm(n_in)
            self.norm2 = nn.LayerNorm(n_in)
        elif norm == "GN":
            self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
            self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
        elif norm == "BN":
            self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
            self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        if activation == "relu":
            self.activation1 = nn.ReLU()
            self.activation2 = nn.ReLU()
        elif activation == "silu":
            self.activation1 = nonlinearity()
            self.activation2 = nonlinearity()
        elif activation == "gelu":
            self.activation1 = nn.GELU()
            self.activation2 = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # Dilated 3x1 conv maps n_in -> n_state; the 1x1 conv maps back to n_in
        # so the residual addition is well defined.
        self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
        self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0)

    def forward(self, x):
        # x: (N, C, T)
        x_orig = x

        if self.norm == "LN":
            # LayerNorm expects channels last: transpose to (N, T, C), normalize,
            # then transpose back to (N, C, T) before the activation and conv.
            x = self.norm1(x.transpose(-2, -1))
            x = self.activation1(x.transpose(-2, -1))
        else:
            x = self.norm1(x)
            x = self.activation1(x)

        x = self.conv1(x)

        if self.norm == "LN":
            x = self.norm2(x.transpose(-2, -1))
            x = self.activation2(x.transpose(-2, -1))
        else:
            x = self.norm2(x)
            x = self.activation2(x)

        x = self.conv2(x)
        x = x + x_orig
        return x


class Resnet1D(nn.Module):
    """Stack of ResConv1DBlocks with exponentially growing dilation
    (dilation_growth_rate ** depth), optionally applied in reverse order."""

    def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
        super().__init__()

        blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth,
                                 activation=activation, norm=norm)
                  for depth in range(n_depth)]
        if reverse_dilation:
            blocks = blocks[::-1]

        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
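

# Minimal usage sketch (not part of the original module): assumes input shaped
# (batch, channels, time) with channels == n_in; note that with GroupNorm ("GN"),
# n_in must be divisible by 32. The residual blocks preserve the input shape.
if __name__ == "__main__":
    net = Resnet1D(n_in=64, n_depth=3, dilation_growth_rate=2, activation='silu', norm="LN")
    x = torch.randn(2, 64, 50)  # (batch, channels, time)
    y = net(x)
    print(y.shape)  # expected: torch.Size([2, 64, 50])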