HoneyTian committed
Commit 65a472d · 1 Parent(s): 8f4f9ae
examples/nx_clean_unet/yaml/config.yaml CHANGED
@@ -6,7 +6,7 @@ n_fft: 512
 win_size: 200
 hop_size: 80
 
-down_sampling_num_layers: 5
+down_sampling_num_layers: 6
 down_sampling_in_channels: 1
 down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
@@ -18,16 +18,16 @@ causal_kernel_size: 3
 causal_bias: false
 causal_separable: true
 causal_f_stride: 1
-causal_num_layers: 3
+causal_num_layers: 5
 
 tsfm_hidden_size: 256
 tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
 tsfm_max_length: 512
-tsfm_chunk_size: 4
-tsfm_num_left_chunks: 64
-tsfm_num_right_chunks: 2
+tsfm_chunk_size: 1
+tsfm_num_left_chunks: 128
+tsfm_num_right_chunks: 4
 
 discriminator_dim: 32
 discriminator_in_channel: 2
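The tsfm_chunk_size / tsfm_num_left_chunks / tsfm_num_right_chunks values parameterize the chunked self-attention mask built by subsequent_chunk_mask in transformers/mask.py. Below is a minimal sketch of how such a mask is typically constructed (WeNet-style, extended with right/look-ahead chunks); the function name chunk_attention_mask and the exact semantics are assumptions, not the repo's verified implementation:

import torch

def chunk_attention_mask(size: int,
                         chunk_size: int,
                         num_left_chunks: int,
                         num_right_chunks: int) -> torch.Tensor:
    # Boolean [size, size] mask: frame i may attend to frame j iff mask[i, j] is True.
    # Assumed semantics: each frame sees its own chunk, num_left_chunks chunks of
    # history and num_right_chunks chunks of look-ahead.
    ret = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        chunk_idx = i // chunk_size
        start = max((chunk_idx - num_left_chunks) * chunk_size, 0)
        ending = min((chunk_idx + 1 + num_right_chunks) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

mask = chunk_attention_mask(size=512, chunk_size=1, num_left_chunks=128, num_right_chunks=4)
print(mask.shape)  # torch.Size([512, 512])

Under this reading, the old values (chunk size 4, 64 left, 2 right) let each frame attend to roughly (64 + 1 + 2) * 4 = 268 frames, while the new values (chunk size 1, 128 left, 4 right) shrink the window to 133 frames with only 4 frames of look-ahead, i.e. a tighter, lower-latency streaming window with more history per frame.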
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import os
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
-from toolbox.torchaudio.models.nx_clean_unet.transformer.transformer import TransformerEncoder
+from toolbox.torchaudio.models.nx_clean_unet.transformers.transformers import TransformerEncoder
 from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder
 
 
@@ -66,10 +66,12 @@ class DownSampling(nn.Module):
 
     def forward(self, x: torch.Tensor):
         # x shape: [batch_size, channels, num_samples]
+        skip_connection_list = list()
         for down_sampling_block in self.down_sampling_block_list:
             x = down_sampling_block.forward(x)
+            skip_connection_list.append(x)
         # x shape: [batch_size, hidden_channels, num_samples**]
-        return x
+        return x, skip_connection_list
 
 
 class UpSamplingBlock(nn.Module):
@@ -134,9 +136,14 @@ class UpSampling(nn.Module):
         up_sampling_block_list.append(up_sampling_block)
         self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]):
+        skip_connection_list = skip_connection_list[::-1]
+
         # x shape: [batch_size, channels, num_samples]
-        for up_sampling_block in self.up_sampling_block_list:
+        for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
+            skip_x = skip_connection_list[idx]
+            x = x + skip_x
+            # x = x + skip_x[:, :, :x.shape[-1]]
             x = up_sampling_block.forward(x)
         return x
 
@@ -209,7 +216,7 @@ class NXCleanUNet(nn.Module):
         )
         noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
 
-        bottle_neck = self.down_sampling.forward(noisy_audios_padded)
+        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
         bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
@@ -226,7 +233,7 @@ class NXCleanUNet(nn.Module):
         bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
-        enhanced_audios = self.up_sampling.forward(bottle_neck)
+        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
 
         enhanced_audios = enhanced_audios[:, :, :n_samples]
         # enhanced_audios shape: [batch_size, 1, n_samples]
@@ -250,7 +257,7 @@ class NXCleanUNet(nn.Module):
         )
         noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
 
-        bottle_neck = self.down_sampling.forward(noisy_audios_padded)
+        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
@@ -267,7 +274,7 @@ class NXCleanUNet(nn.Module):
         bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
-        enhanced_audios = self.up_sampling.forward(bottle_neck)
+        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
 
         enhanced_audios = enhanced_audios[:, :, :n_samples]
         # enhanced_audios shape: [batch_size, 1, n_samples]
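This change turns the encoder/decoder pair into a U-Net with additive skip connections: DownSampling now returns the per-layer feature maps alongside the bottleneck, and UpSampling consumes them deepest-first. A self-contained sketch of the pattern, with plain Conv1d/ConvTranspose1d blocks standing in for the repo's DownSamplingBlock/UpSamplingBlock (layer counts and kernel sizes here are illustrative, not the config values):

import torch
import torch.nn as nn

class TinyUNet1d(nn.Module):
    # Additive-skip U-Net over 1-d signals: collect encoder outputs on the way
    # down, then add them back (deepest first) before each decoder block.
    def __init__(self, channels: int = 8, num_layers: int = 3):
        super().__init__()
        self.down_blocks = nn.ModuleList([
            nn.Conv1d(1 if i == 0 else channels, channels,
                      kernel_size=4, stride=2, padding=1)
            for i in range(num_layers)
        ])
        self.up_blocks = nn.ModuleList([
            nn.ConvTranspose1d(channels, 1 if i == num_layers - 1 else channels,
                               kernel_size=4, stride=2, padding=1)
            for i in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip_connection_list = []
        for block in self.down_blocks:
            x = block(x)
            skip_connection_list.append(x)
        # Deepest feature map first, mirroring UpSampling.forward in this commit.
        skip_connection_list = skip_connection_list[::-1]
        for skip_x, block in zip(skip_connection_list, self.up_blocks):
            x = x + skip_x  # additive skip; shapes match because strides mirror
            x = block(x)
        return x

net = TinyUNet1d()
out = net(torch.randn(2, 1, 160))  # [batch_size, 1, num_samples]
print(out.shape)                   # torch.Size([2, 1, 160])

Because the stride-2 transposed convolutions exactly mirror the stride-2 convolutions in this sketch, x and skip_x always share a shape; the commented-out skip_x[:, :, :x.shape[-1]] line in the commit suggests the real blocks may need length trimming when the two drift apart by a few samples.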
toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/__init__.py RENAMED
File without changes
toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/attention.py RENAMED
File without changes
toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/mask.py RENAMED
File without changes
toolbox/torchaudio/models/nx_clean_unet/{transformer/transformer.py → transformers/transformers.py} RENAMED
@@ -1,14 +1,12 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
-import math
 from typing import Dict, Optional, Tuple, List, Union
 
 import torch
 import torch.nn as nn
-from fontTools.subset import prune_post_subset
 
-from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
-from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
+from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask
+from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
 
 
 class PositionwiseFeedForward(nn.Module):
@@ -41,7 +39,7 @@ class PositionwiseFeedForward(nn.Module):
         return self.w_2(self.dropout(self.activation(self.w_1(xs))))
 
 
-class TransformerEncoderLayer(nn.Module):
+class TransformerBlock(nn.Module):
     def __init__(self,
                  input_dim: int,
                  dropout_rate: float = 0.1,
@@ -129,7 +127,7 @@ class TransformerEncoder(nn.Module):
         )
 
         self.encoder_layer_list = torch.nn.ModuleList([
-            TransformerEncoderLayer(
+            TransformerBlock(
                 input_dim=hidden_size,
                 n_heads=attention_heads,
                 dropout_rate=dropout_rate,
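Any downstream code importing from the old transformer package path must be updated alongside this rename; the modeling file above already does so. A hedged sketch of constructing the encoder with the new config values, assuming the TransformerEncoder constructor mirrors the tsfm_* keys in config.yaml (the argument names are assumptions, not verified against the class signature):

from toolbox.torchaudio.models.nx_clean_unet.transformers.transformers import TransformerEncoder

# Argument names mirror the tsfm_* keys in config.yaml; the real signature
# may differ -- treat this as an assumption, not the verified API.
encoder = TransformerEncoder(
    hidden_size=256,
    attention_heads=8,
    num_blocks=6,
    dropout_rate=0.1,
    max_length=512,
    chunk_size=1,
    num_left_chunks=128,
    num_right_chunks=4,
)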