HoneyTian committed on
Commit
7f9c54f
·
1 Parent(s): f74ae8e
examples/mpnet_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
- sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -49,7 +49,6 @@ def get_args():
49
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
50
  parser.add_argument("--patience", default=5, type=int)
51
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
52
- parser.add_argument("--seed", default=0, type=int)
53
 
54
  parser.add_argument("--config_file", default="config.yaml", type=str)
55
 
@@ -79,110 +78,31 @@ def logging_config(file_dir: str):
79
 
80
 
81
  class CollateFunction(object):
82
- def __init__(self,
83
- n_fft: int = 512,
84
- win_length: int = 200,
85
- hop_length: int = 80,
86
- window_fn: str = "hamming",
87
- irm_beta: float = 1.0,
88
- epsilon: float = 1e-8,
89
- ):
90
- self.n_fft = n_fft
91
- self.win_length = win_length
92
- self.hop_length = hop_length
93
- self.window_fn = window_fn
94
- self.irm_beta = irm_beta
95
- self.epsilon = epsilon
96
-
97
- self.transform = torchaudio.transforms.Spectrogram(
98
- n_fft=self.n_fft,
99
- win_length=self.win_length,
100
- hop_length=self.hop_length,
101
- power=2.0,
102
- window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
103
- )
104
-
105
- @staticmethod
106
- def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
107
- batch_size, channels, freq_dim, time_steps = x.shape
108
-
109
- # kernel: [freq_dim, n_time_step]
110
- kernel_size = (freq_dim, n_time_steps)
111
-
112
- # pad
113
- pad = n_time_steps // 2
114
- x = torch.concat(tensors=[
115
- x[:, :, :, :pad],
116
- x,
117
- x[:, :, :, -pad:],
118
- ], dim=-1)
119
-
120
- x = F.unfold(
121
- input=x,
122
- kernel_size=kernel_size,
123
- )
124
- # x shape: [batch_size, fold, time_steps]
125
- return x
126
 
127
  def __call__(self, batch: List[dict]):
128
- mix_spec_list = list()
129
- speech_irm_list = list()
130
- snr_db_list = list()
131
  for sample in batch:
132
- noise_wave: torch.Tensor = sample["noise_wave"]
133
- speech_wave: torch.Tensor = sample["speech_wave"]
134
- mix_wave: torch.Tensor = sample["mix_wave"]
135
  # snr_db: float = sample["snr_db"]
136
 
137
- noise_spec = self.transform.forward(noise_wave)
138
- speech_spec = self.transform.forward(speech_wave)
139
- mix_spec = self.transform.forward(mix_wave)
140
-
141
- # noise_irm = noise_spec / (noise_spec + speech_spec)
142
- speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
143
- speech_irm = torch.pow(speech_irm, self.irm_beta)
144
-
145
- # noise_spec, speech_spec, mix_spec, speech_irm
146
- # shape: [freq_dim, time_steps]
147
-
148
- snr_db: torch.Tensor = 10 * torch.log10(
149
- speech_spec / (noise_spec + self.epsilon)
150
- )
151
- snr_db = torch.clamp(snr_db, min=self.epsilon)
152
-
153
- snr_db_ = torch.unsqueeze(snr_db, dim=0)
154
- snr_db_ = torch.unsqueeze(snr_db_, dim=0)
155
- snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
156
- snr_db_ = torch.squeeze(snr_db_, dim=0)
157
- # snr_db_ shape: [fold, time_steps]
158
-
159
- snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
160
- # snr_db shape: [1, time_steps]
161
-
162
- mix_spec_list.append(mix_spec)
163
- speech_irm_list.append(speech_irm)
164
- snr_db_list.append(snr_db)
165
-
166
- mix_spec_list = torch.stack(mix_spec_list)
167
- speech_irm_list = torch.stack(speech_irm_list)
168
- snr_db_list = torch.stack(snr_db_list) # shape: (batch_size, time_steps, 1)
169
-
170
- mix_spec_list = mix_spec_list[:, :-1, :]
171
- speech_irm_list = speech_irm_list[:, :-1, :]
172
 
173
- # mix_spec_list shape: [batch_size, freq_dim, time_steps]
174
- # speech_irm_list shape: [batch_size, freq_dim, time_steps]
175
- # snr_db shape: [batch_size, 1, time_steps]
176
 
177
  # assert
178
- if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
179
- raise AssertionError("nan or inf in mix_spec_list")
180
- if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
181
- raise AssertionError("nan or inf in speech_irm_list")
182
- if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
183
- raise AssertionError("nan or inf in snr_db_list")
184
-
185
- return mix_spec_list, speech_irm_list, snr_db_list
186
 
187
 
188
  collate_fn = CollateFunction()
 
49
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
50
  parser.add_argument("--patience", default=5, type=int)
51
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
52
 
53
  parser.add_argument("--config_file", default="config.yaml", type=str)
54
 
 
78
 
79
 
80
  class CollateFunction(object):
81
+ def __init__(self):
82
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  def __call__(self, batch: List[dict]):
85
+ clean_audios = list()
86
+ noisy_audios = list()
87
+
88
  for sample in batch:
89
+ # noise_wave: torch.Tensor = sample["noise_wave"]
90
+ clean_audio: torch.Tensor = sample["speech_wave"]
91
+ noisy_audio: torch.Tensor = sample["mix_wave"]
92
  # snr_db: float = sample["snr_db"]
93
 
94
+ clean_audios.append(clean_audio)
95
+ noisy_audios.append(noisy_audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ clean_audios = torch.stack(clean_audios)
98
+ noisy_audios = torch.stack(noisy_audios)
 
99
 
100
  # assert
101
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
102
+ raise AssertionError("nan or inf in clean_audios")
103
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
104
+ raise AssertionError("nan or inf in noisy_audios")
105
+ return clean_audios, noisy_audios
 
 
 
106
 
107
 
108
  collate_fn = CollateFunction()