# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from asteroid_filterbanks import Encoder, ParamSincFB

from .RawNetBasicBlock import Bottle2neck, PreEmphasis


class RawNet3(nn.Module):
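    """RawNet3 embedding extractor operating directly on raw waveforms.

    Pipeline, as built in __init__ and forward: pre-emphasis plus instance
    normalization, a learnable sinc filterbank front end, three backbone
    blocks whose outputs are concatenated and projected to 1536 channels,
    attentive statistics pooling over time, and a linear layer producing an
    ``nOut``-dimensional utterance embedding.
    """
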
    def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
        super().__init__()

        nOut = kwargs["nOut"]

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        # Pre-emphasis followed by instance normalization on the raw waveform.
        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )
        # Learnable sinc filterbank front end with C // 4 filters of length 251.
        self.conv1 = Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        )
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C // 4)

        # Backbone blocks; their outputs are concatenated and projected
        # to 1536 channels by layer4.
        self.layer1 = block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        )
        self.layer2 = block(C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3)
        self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)

        # With context, the attention input is the frame-level features
        # concatenated with their per-utterance mean and std (3 * 1536 channels).
        if self.context:
            attn_input = 1536 * 3
        else:
            attn_input = 1536
        print("self.encoder_type", self.encoder_type)
        # "ECA" predicts one weight per channel and frame; "ASP" predicts a
        # single weight per frame shared across channels.
        if self.encoder_type == "ECA":
            attn_output = 1536
        elif self.encoder_type == "ASP":
            attn_output = 1
        else:
            raise ValueError("Undefined encoder")

        self.attention = nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        )

        # 3072 = 2 * 1536: concatenated weighted mean and std from pooling.
        self.bn5 = nn.BatchNorm1d(3072)
        self.fc6 = nn.Linear(3072, nOut)
        self.bn6 = nn.BatchNorm1d(nOut)

        self.mp3 = nn.MaxPool1d(3)
    def forward(self, x):
        """
        :param x: input mini-batch (bs, samp)
        """
        # Keep the pre-emphasis and sinc front end in full precision even
        # under mixed-precision training.
        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self.conv1(x))
            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                s = torch.std(x, dim=-1, keepdim=True)
                s[s < 0.001] = 0.001
                x = (x - m) / s

        if self.summed:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(self.mp3(x1) + x2)
        else:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(x2)

        # Concatenate the three block outputs and project to 1536 channels.
        x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        # Optionally append the per-utterance mean and std of the frame-level
        # features as context for the attention weights.
        if self.context:
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(min=1e-4, max=1e4)
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        # Attentive statistics pooling: weighted mean and std over time.
        w = self.attention(global_x)
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )
        x = torch.cat((mu, sg), 1)

        x = self.bn5(x)
        x = self.fc6(x)

        if self.out_bn:
            x = self.bn6(x)

        return x


def MainModel(**kwargs):
    model = RawNet3(Bottle2neck, model_scale=8, context=True, summed=True, **kwargs)
    return model
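

# Example usage (illustrative sketch, not part of the original module; only the
# keyword names below come from __init__, the values are assumptions):
#
#   model = MainModel(
#       nOut=256,            # embedding dimension (assumed value)
#       encoder_type="ECA",  # or "ASP"
#       log_sinc=True,
#       norm_sinc="mean",    # or "mean_std"
#       out_bn=False,
#       sinc_stride=10,      # stride of the sinc filterbank (assumed value)
#   )
#   wav = torch.randn(2, 48000)  # (bs, samp) raw waveform mini-batch
#   emb = model(wav)             # (bs, nOut) utterance embedding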