3v324v23 committed
Commit 0327556 · Parent(s): cda63cb
Files changed (1):
  ecapa_tdnn.py  +396 -0
ecapa_tdnn.py ADDED
@@ -0,0 +1,396 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans


# Res2Conv1d + BatchNorm1d + ReLU
class Res2Conv1dReluBn(nn.Module):
    """
    in_channels == out_channels == channels
    """

    def __init__(
        self,
        channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        scale=4,
    ):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(
                nn.Conv1d(
                    self.width,
                    self.width,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    bias=bias,
                )
            )
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)

        return out


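# A minimal, illustrative sanity check for Res2Conv1dReluBn, assuming a
# (batch, channels, time) input as used elsewhere in this file. The _demo_*
# helper is only a sketch and is never called by the model.
def _demo_res2conv_shapes():
    block = Res2Conv1dReluBn(channels=64, kernel_size=3, padding=1, scale=4)
    x = torch.randn(2, 64, 100)  # (batch, channels, time)
    y = block(x)
    # The scale groups are convolved sequentially and re-concatenated,
    # so the channel and time dimensions are preserved: (2, 64, 100).
    return y.shape

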
# Conv1d + BatchNorm1d + ReLU
class Conv1dReluBn(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
        )
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


# The SE connection for the 1D case.
class SE_Connect(nn.Module):
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)

        return out


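# Minimal sketch of how SE_Connect rescales channels, assuming a
# (batch, channels, time) input; the _demo_* helper is illustrative only.
def _demo_se_connect():
    se = SE_Connect(channels=64, se_bottleneck_dim=16)
    x = torch.randn(2, 64, 100)
    # Per-channel gates in (0, 1), computed from the time-averaged features,
    # scale x; the shape stays (2, 64, 100).
    y = se(x)
    return y.shape

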
# SE-Res2Block of the ECAPA-TDNN architecture.
class SE_Res2Block(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        scale,
        se_bottleneck_dim,
    ):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(
            out_channels, kernel_size, stride, padding, dilation, scale=scale
        )
        self.Conv1dReluBn2 = Conv1dReluBn(
            out_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.Conv1dReluBn1(x)
        x = self.Res2Conv1dReluBn(x)
        x = self.Conv1dReluBn2(x)
        x = self.SE_Connect(x)

        return x + residual


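# Sketch of one SE-Res2Block call with the same hyperparameters as layer2
# below (dilated kernel-size-3 Res2 convolution, scale=8, SE bottleneck 128).
# The residual path keeps the (batch, channels, time) shape unchanged; the
# _demo_* helper is illustrative only.
def _demo_se_res2block():
    block = SE_Res2Block(
        in_channels=512,
        out_channels=512,
        kernel_size=3,
        stride=1,
        padding=2,
        dilation=2,
        scale=8,
        se_bottleneck_dim=128,
    )
    x = torch.randn(2, 512, 100)
    return block(x).shape  # (2, 512, 100)

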
# Attentive weighted mean and standard deviation pooling.
class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, so we don't need to transpose the inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(
                in_dim * 3, attention_channels, kernel_size=1
            )  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(
                in_dim, attention_channels, kernel_size=1
            )  # equals W and b in the paper
        self.linear2 = nn.Conv1d(
            attention_channels, in_dim, kernel_size=1
        )  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True) + 1e-10
            ).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, the attention was hard to converge with ReLU.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


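# Minimal sketch of the attentive statistics pooling above: attention weights
# alpha_t (softmax over time) give a weighted mean and a weighted standard
# deviation, std = sqrt(E_alpha[x^2] - E_alpha[x]^2), concatenated per channel.
# The _demo_* helper is illustrative only.
def _demo_attentive_stats_pool():
    pool = AttentiveStatsPool(in_dim=1536, attention_channels=128)
    x = torch.randn(2, 1536, 100)
    return pool(x).shape  # (2, 3072): mean and std stacked along channels

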
class ECAPA_TDNN(nn.Module):
    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        feat_type="fbank",
        sr=16000,
        feature_selection="hidden_states",
        update_extract=False,
        config_path=None,
    ):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        if feat_type == "fbank" or feat_type == "mfcc":
            self.update_extract = False

        win_len = int(sr * 0.025)
        hop_len = int(sr * 0.01)

        if feat_type == "fbank":
            self.feature_extract = trans.MelSpectrogram(
                sample_rate=sr,
                n_fft=512,
                win_length=win_len,
                hop_length=hop_len,
                f_min=0.0,
                f_max=sr // 2,
                pad=0,
                n_mels=feat_dim,
            )
        elif feat_type == "mfcc":
            melkwargs = {
                "n_fft": 512,
                "win_length": win_len,
                "hop_length": hop_len,
                "f_min": 0.0,
                "f_max": sr // 2,
                "pad": 0,
            }
            self.feature_extract = trans.MFCC(
                sample_rate=sr, n_mfcc=feat_dim, log_mels=False, melkwargs=melkwargs
            )
        else:
            if config_path is None:
                torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
                self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
            else:
                # NOTE: UpstreamExpert is not imported in this file; it must be
                # provided by the surrounding project (typically an s3prl
                # upstream wrapper).
                self.feature_extract = UpstreamExpert(config_path)
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
                self.feature_extract.model.encoder.layers[23].self_attn,
                "fp32_attention",
            ):
                self.feature_extract.model.encoder.layers[
                    23
                ].self_attn.fp32_attention = False
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
                self.feature_extract.model.encoder.layers[11].self_attn,
                "fp32_attention",
            ):
                self.feature_extract.model.encoder.layers[
                    11
                ].self_attn.fp32_attention = False

            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != "fbank" and feat_type != "mfcc":
            freeze_list = [
                "final_proj",
                "label_embs_concat",
                "mask_emb",
                "project_q",
                "quantizer",
            ]
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1],
            attention_channels=128,
            global_context_att=global_context_att,
        )
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == "fbank" or self.feat_type == "mfcc":
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == "fbank":
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = (
                F.softmax(self.feature_weight, dim=-1)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .unsqueeze(-1)
            )
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out


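# Minimal usage sketch with torchaudio fbank features (no s3prl download),
# assuming raw 16 kHz waveforms of shape (batch, samples). The _demo_* helper
# is illustrative only and runs only when called explicitly.
def _demo_ecapa_fbank():
    model = ECAPA_TDNN(feat_dim=80, channels=512, emb_dim=192, feat_type="fbank")
    wav = torch.randn(2, 32000)  # two 2-second utterances at 16 kHz
    emb = model(wav)
    return emb.shape  # (2, 192): one embedding per utterance

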
if __name__ == "__main__":
    x = torch.zeros(2, 32000)
    model = ECAPA_TDNN(
        feat_dim=768,
        emb_dim=256,
        feat_type="hubert_base",
        feature_selection="hidden_states",
        update_extract=False,
    )

    out = model(x)
    print(out.shape)
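    # The printed shape should be torch.Size([2, 256]): one 256-dimensional
    # speaker embedding per input waveform. This branch loads the hubert_base
    # upstream via torch.hub (s3prl), so it typically needs network access on
    # the first run.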