Spaces:
Sleeping
Sleeping
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from collections import OrderedDict | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from modules.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear | |
class FCM(nn.Module):
    """2-D convolutional front-end that runs ahead of the TDNN stack.

    The input receives a singleton channel axis, then passes through a 3x3
    convolution, two stacks of residual blocks and a final 3x3 convolution
    with stride ``(2, 1)``. The channel axis and the following axis of the
    result are merged, so the output is 3-D again.

    Args:
        block: Residual block class. It is instantiated as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` class attribute.
        num_blocks: Number of residual blocks in ``layer1`` and ``layer2``
            respectively; only indices 0 and 1 are read.
        m_channels: Number of channels produced by ``conv1``, requested
            from both residual stacks, and consumed/produced by ``conv2``.
        feat_dim: Size of the input feature axis. Only used to compute
            ``out_channels``.

    Attributes:
        in_planes: Input width for the next residual block; updated by
            ``_make_layer`` as blocks are created.
        out_channels: ``m_channels * (feat_dim // 8)``, the size advertised
            for axis 1 of the tensor returned by ``forward``.
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=(2, 2),
                 m_channels=32,
                 feat_dim=80):
        # ``num_blocks`` defaults to a tuple rather than a list so the
        # default object cannot be mutated and shared between instances.
        # It is only indexed, so callers passing a list keep working.
        super().__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        # Stride (2, 1): halves axis 2 of the 4-D tensor, keeps axis 3.
        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)

        # NOTE(review): the divisor 8 presumably reflects three halvings of
        # the feature axis (layer1, layer2, conv2) and assumes ``block``
        # applies its stride to that axis only and that ``feat_dim`` is a
        # multiple of 8 -- confirm against the block implementation.
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build a sequential stack of ``num_blocks`` residual blocks.

        The first block is created with ``stride``; every following block
        uses stride 1. ``self.in_planes`` is advanced to
        ``planes * block.expansion`` after each block so the next one is
        created with the matching input width.

        Returns:
            ``nn.Sequential`` holding the blocks in creation order.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolutional front-end.

        Args:
            x: Input tensor. A channel axis is inserted at position 1
                before the 2-D convolutions; presumably ``x`` is
                (batch, feature, time) -- confirm against callers.

        Returns:
            3-D tensor in which axes 1 and 2 of the 4-D convolution output
            have been merged into a single axis.
        """
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        # Fold axis 2 into the channel axis so the result is 3-D again.
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
class CAMPPlus(nn.Module):
    """FCM front-end followed by a densely connected TDNN stack.

    The network is ``head`` (an ``FCM``) followed by ``xvector``, an
    ``nn.Sequential`` whose named stages are, in order: ``tdnn``,
    ``block1``/``transit1``, ``block2``/``transit2``,
    ``block3``/``transit3``, ``out_nonlinear``, ``stats`` and ``dense``.

    Args:
        feat_dim: Input feature size, forwarded to ``FCM``.
        embedding_size: Output size of the final ``dense`` layer.
        growth_rate: Channels each dense-block layer contributes; a block
            with ``n`` layers widens the running channel count by
            ``n * growth_rate``.
        bn_size: Multiplier giving each block's ``bn_channels`` as
            ``bn_size * growth_rate``.
        init_channels: Output channels of the first ``tdnn`` layer.
        config_str: Configuration string passed to the TDNN, block and
            transit layers and to ``get_nonlinear``.
        memory_efficient: Forwarded unchanged to every
            ``CAMDenseTDNNBlock``.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=512,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super().__init__()

        self.head = FCM(feat_dim=feat_dim)

        # Collect the named stages first; they are wrapped in a single
        # nn.Sequential once the whole pipeline has been described.
        stages = OrderedDict()
        stages['tdnn'] = TDNNLayer(self.head.out_channels,
                                   init_channels,
                                   5,
                                   stride=2,
                                   dilation=1,
                                   padding=-1,
                                   config_str=config_str)

        # (num_layers, kernel_size, dilation) for each dense block.
        block_specs = ((12, 3, 1), (24, 3, 2), (16, 3, 2))
        width = init_channels
        for idx, (depth, ksize, dil) in enumerate(block_specs, start=1):
            stages['block%d' % idx] = CAMDenseTDNNBlock(
                num_layers=depth,
                in_channels=width,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=ksize,
                dilation=dil,
                config_str=config_str,
                memory_efficient=memory_efficient)
            # The block widens the running channel count ...
            width += depth * growth_rate
            # ... and the transit layer that follows halves it.
            stages['transit%d' % idx] = TransitLayer(width,
                                                     width // 2,
                                                     bias=False,
                                                     config_str=config_str)
            width //= 2

        stages['out_nonlinear'] = get_nonlinear(config_str, width)
        stages['stats'] = StatsPool()
        # The dense layer is built for twice the pooled channel count.
        stages['dense'] = DenseLayer(width * 2,
                                     embedding_size,
                                     config_str='batchnorm_')
        self.xvector = nn.Sequential(stages)

        # Kaiming-normal weights and zero biases for every Conv1d / Linear
        # found anywhere in the module tree.
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight.data)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the head and the ``xvector`` stack on ``x``.

        Args:
            x: Tensor laid out as (B, T, F); its last two axes are swapped
                before it is handed to the head.

        Returns:
            Output of the final ``dense`` stage of ``xvector``.
        """
        feats = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        return self.xvector(self.head(feats))