HoneyTian commited on
Commit
32aa651
·
1 Parent(s): 4f045d5
examples/clean_unet_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -243,7 +243,7 @@ def main():
243
 
244
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
245
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
246
- pesq_score = run_pesq_score(enhanced_audios_list_r, clean_audios_list_r, sample_rate=8000, mode="nb")
247
 
248
  optimizer.zero_grad()
249
  loss.backward()
@@ -304,7 +304,7 @@ def main():
304
 
305
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
306
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
307
- pesq_score = run_pesq_score(enhanced_audios_list_r, clean_audios_list_r, sample_rate=8000, mode="nb")
308
 
309
  total_pesq_score += pesq_score
310
  total_loss += loss.item()
 
243
 
244
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
245
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
246
+ pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")
247
 
248
  optimizer.zero_grad()
249
  loss.backward()
 
304
 
305
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
306
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
307
+ pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")
308
 
309
  total_pesq_score += pesq_score
310
  total_loss += loss.item()
examples/mpnet/step_2_train_model.py CHANGED
@@ -26,9 +26,10 @@ from tqdm import tqdm
26
 
27
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
28
  from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
29
- from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel, batch_pesq
30
- from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
31
  from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
 
32
 
33
 
34
  def get_args():
@@ -251,7 +252,7 @@ def main():
251
  mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
252
 
253
  audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
254
- batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)
255
 
256
  # Discriminator
257
  optim_d.zero_grad()
@@ -259,11 +260,12 @@ def main():
259
  metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
260
  loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
261
 
262
- if batch_pesq_score is not None:
263
- loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
264
- else:
265
- # print("pesq is None!")
266
  loss_disc_g = 0
 
 
 
267
 
268
  loss_disc_all = loss_disc_r + loss_disc_g
269
  loss_disc_all.backward()
@@ -334,11 +336,17 @@ def main():
334
  audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
335
  mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
336
 
337
- total_pesq_score += pesq_score(
338
- torch.split(clean_audio, 1, dim=0),
339
- torch.split(audio_g, 1, dim=0),
340
- config
341
- ).item()
 
 
 
 
 
 
342
  total_mag_err += F.mse_loss(clean_mag, mag_g).item()
343
  val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
344
  total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
 
26
 
27
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
28
  from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
29
+ from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel
30
+ from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses
31
  from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
32
+ from toolbox.torchaudio.models.mpnet.metrics import run_batch_pesq, run_pesq_score
33
 
34
 
35
  def get_args():
 
252
  mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
253
 
254
  audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
255
+ pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb")
256
 
257
  # Discriminator
258
  optim_d.zero_grad()
 
260
  metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
261
  loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
262
 
263
+ if -1 in pesq_score_list:
264
+ # print("-1 in batch_pesq_score!")
 
 
265
  loss_disc_g = 0
266
+ else:
267
+ pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
268
+ loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
269
 
270
  loss_disc_all = loss_disc_r + loss_disc_g
271
  loss_disc_all.backward()
 
336
  audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
337
  mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
338
 
339
+ clean_audio_list = torch.split(clean_audio, 1, dim=0)
340
+ enhanced_audio_list = torch.split(audio_g, 1, dim=0)
341
+ clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
342
+ enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]
343
+ pesq_score = run_pesq_score(
344
+ clean_audio_list,
345
+ enhanced_audio_list,
346
+ sample_rate = config.sample_rate,
347
+ mode = "nb",
348
+ )
349
+ total_pesq_score += pesq_score
350
  total_mag_err += F.mse_loss(clean_mag, mag_g).item()
351
  val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
352
  total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()