update
- examples/nx_clean_unet/yaml/config.yaml +5 -5
- toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py +16 -9
- toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/__init__.py +0 -0
- toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/attention.py +0 -0
- toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/mask.py +0 -0
- toolbox/torchaudio/models/nx_clean_unet/{transformer/transformer.py → transformers/transformers.py} +4 -6
examples/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -6,7 +6,7 @@ n_fft: 512
 win_size: 200
 hop_size: 80
 
-down_sampling_num_layers:
+down_sampling_num_layers: 6
 down_sampling_in_channels: 1
 down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
@@ -18,16 +18,16 @@ causal_kernel_size: 3
 causal_bias: false
 causal_separable: true
 causal_f_stride: 1
-causal_num_layers:
+causal_num_layers: 5
 
 tsfm_hidden_size: 256
 tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
 tsfm_max_length: 512
-tsfm_chunk_size:
-tsfm_num_left_chunks:
-tsfm_num_right_chunks:
+tsfm_chunk_size: 1
+tsfm_num_left_chunks: 128
+tsfm_num_right_chunks: 4
 
 discriminator_dim: 32
 discriminator_in_channel: 2
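The newly pinned tsfm_chunk_size, tsfm_num_left_chunks and tsfm_num_right_chunks values parameterize chunk-based self-attention in the transformer bottleneck; the renamed transformers module imports subsequent_chunk_mask for exactly this purpose. The sketch below is a hypothetical stand-in, not the repository's function (the exact mask semantics of subsequent_chunk_mask here are an assumption): each position may attend to a bounded number of past chunks and a small number of future chunks.

import torch

def chunk_attention_mask(size: int,
                         chunk_size: int,
                         num_left_chunks: int = -1,
                         num_right_chunks: int = 0) -> torch.Tensor:
    # mask[i, j] is True where position i is allowed to attend to position j
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        chunk_idx = i // chunk_size
        if num_left_chunks < 0:
            start = 0  # unlimited left context
        else:
            start = max((chunk_idx - num_left_chunks) * chunk_size, 0)
        end = min((chunk_idx + 1 + num_right_chunks) * chunk_size, size)
        mask[i, start:end] = True
    return mask

# e.g. with chunk_size=2, one left chunk, no lookahead: positions 2-3 may
# attend to positions 0-3, positions 4-5 to positions 2-5, and so on.
print(chunk_attention_mask(6, 2, num_left_chunks=1, num_right_chunks=0))

Under this reading, the committed values (chunk size 1, 128 left chunks, 4 right chunks) would give each time step a 128-step history window plus 4 steps of lookahead, keeping the bottleneck streamable with a fixed latency.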
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import os
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
-from toolbox.torchaudio.models.nx_clean_unet.transformer.transformer import TransformerEncoder
+from toolbox.torchaudio.models.nx_clean_unet.transformers.transformers import TransformerEncoder
 from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder
 
 
@@ -66,10 +66,12 @@ class DownSampling(nn.Module):
 
     def forward(self, x: torch.Tensor):
         # x shape: [batch_size, channels, num_samples]
+        skip_connection_list = list()
         for down_sampling_block in self.down_sampling_block_list:
             x = down_sampling_block.forward(x)
+            skip_connection_list.append(x)
         # x shape: [batch_size, hidden_channels, num_samples**]
-        return x
+        return x, skip_connection_list
 
 
 class UpSamplingBlock(nn.Module):
@@ -134,9 +136,14 @@ class UpSampling(nn.Module):
         up_sampling_block_list.append(up_sampling_block)
         self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]):
+        skip_connection_list = skip_connection_list[::-1]
+
         # x shape: [batch_size, channels, num_samples]
-        for up_sampling_block in self.up_sampling_block_list:
+        for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
+            skip_x = skip_connection_list[idx]
+            x = x + skip_x
+            # x = x + skip_x[:, :, :x.shape[-1]]
             x = up_sampling_block.forward(x)
         return x
 
@@ -209,7 +216,7 @@ class NXCleanUNet(nn.Module):
         )
         noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
 
-        bottle_neck = self.down_sampling.forward(noisy_audios_padded)
+        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
         bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
@@ -226,7 +233,7 @@ class NXCleanUNet(nn.Module):
         bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
-        enhanced_audios = self.up_sampling.forward(bottle_neck)
+        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
 
         enhanced_audios = enhanced_audios[:, :, :n_samples]
         # enhanced_audios shape: [batch_size, 1, n_samples]
@@ -250,7 +257,7 @@ class NXCleanUNet(nn.Module):
         )
         noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
 
-        bottle_neck = self.down_sampling.forward(noisy_audios_padded)
+        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
@@ -267,7 +274,7 @@ class NXCleanUNet(nn.Module):
         bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
         # bottle_neck shape: [batch_size, channels, time_steps]
 
-        enhanced_audios = self.up_sampling.forward(bottle_neck)
+        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
 
         enhanced_audios = enhanced_audios[:, :, :n_samples]
         # enhanced_audios shape: [batch_size, 1, n_samples]
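The substance of this change is a U-Net style skip path: DownSampling now collects each block's output and returns it alongside the bottleneck, and UpSampling walks that list in reverse, adding each saved activation to the decoder input before upsampling (a length-cropping variant is left commented out). Below is a minimal, self-contained sketch of the same pattern; plain Conv1d / ConvTranspose1d layers stand in for the repository's DownSamplingBlock and UpSamplingBlock, and all names and sizes are illustrative only.

import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    def __init__(self, hidden_channels: int = 8, num_layers: int = 3):
        super().__init__()
        self.down_blocks = nn.ModuleList([
            nn.Conv1d(1 if i == 0 else hidden_channels, hidden_channels,
                      kernel_size=4, stride=2, padding=1)
            for i in range(num_layers)
        ])
        self.up_blocks = nn.ModuleList([
            nn.ConvTranspose1d(hidden_channels,
                               1 if i == num_layers - 1 else hidden_channels,
                               kernel_size=4, stride=2, padding=1)
            for i in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip_connection_list = []
        for block in self.down_blocks:
            x = block(x)
            skip_connection_list.append(x)             # save each encoder output
        skip_connection_list = skip_connection_list[::-1]  # deepest output first
        for idx, block in enumerate(self.up_blocks):
            x = x + skip_connection_list[idx]          # additive skip, then upsample
            x = block(x)
        return x

x = torch.randn(2, 1, 64)        # [batch_size, channels, num_samples]
print(TinyUNet()(x).shape)       # torch.Size([2, 1, 64])

Addition, rather than concatenation, keeps channel counts unchanged, so the saved tensors can be summed directly as long as encoder and decoder lengths line up; the commented-out skip_x[:, :, :x.shape[-1]] line in the diff hints at cropping for the case where they do not.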
toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/__init__.py
RENAMED
File without changes

toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/attention.py
RENAMED
File without changes

toolbox/torchaudio/models/nx_clean_unet/{transformer → transformers}/mask.py
RENAMED
File without changes
toolbox/torchaudio/models/nx_clean_unet/{transformer/transformer.py → transformers/transformers.py}
RENAMED
@@ -1,14 +1,12 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
-import math
 from typing import Dict, Optional, Tuple, List, Union
 
 import torch
 import torch.nn as nn
-from fontTools.subset import prune_post_subset
 
-from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
-from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
+from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask
+from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
 
 
 class PositionwiseFeedForward(nn.Module):
@@ -41,7 +39,7 @@ class PositionwiseFeedForward(nn.Module):
         return self.w_2(self.dropout(self.activation(self.w_1(xs))))
 
 
-class
+class TransformerBlock(nn.Module):
     def __init__(self,
                  input_dim: int,
                  dropout_rate: float = 0.1,
@@ -129,7 +127,7 @@ class TransformerEncoder(nn.Module):
         )
 
         self.encoder_layer_list = torch.nn.ModuleList([
-
+            TransformerBlock(
                 input_dim=hidden_size,
                 n_heads=attention_heads,
                 dropout_rate=dropout_rate,
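The diff shows only the new class name TransformerBlock and its constructor arguments (input_dim, n_heads, dropout_rate); the block body is untouched by this commit and not visible here. For orientation, the sketch below is a generic pre-norm self-attention block with that interface. It is an assumption, not the repository's implementation, and it uses torch.nn.MultiheadAttention in place of the imported MultiHeadSelfAttention / RelativeMultiHeadSelfAttention modules.

from typing import Optional

import torch
import torch.nn as nn

class TransformerBlockSketch(nn.Module):
    def __init__(self, input_dim: int, n_heads: int = 8, dropout_rate: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(input_dim)
        self.attention = nn.MultiheadAttention(input_dim, n_heads,
                                               dropout=dropout_rate,
                                               batch_first=True)
        self.norm2 = nn.LayerNorm(input_dim)
        self.ffn = nn.Sequential(                 # position-wise feed-forward
            nn.Linear(input_dim, 4 * input_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(4 * input_dim, input_dim),
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # self-attention sub-layer with residual connection
        h = self.norm1(x)
        h, _ = self.attention(h, h, h, attn_mask=attn_mask)
        x = x + self.dropout(h)
        # feed-forward sub-layer with residual connection
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x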