Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/run.sh
CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
-
sh run.sh --stage
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -243,7 +243,7 @@ def main():
|
|
243 |
|
244 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
245 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
246 |
-
pesq_score = run_pesq_score(
|
247 |
|
248 |
optimizer.zero_grad()
|
249 |
loss.backward()
|
@@ -304,7 +304,7 @@ def main():
|
|
304 |
|
305 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
306 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
307 |
-
pesq_score = run_pesq_score(
|
308 |
|
309 |
total_pesq_score += pesq_score
|
310 |
total_loss += loss.item()
|
|
|
243 |
|
244 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
245 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
246 |
+
pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")
|
247 |
|
248 |
optimizer.zero_grad()
|
249 |
loss.backward()
|
|
|
304 |
|
305 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
306 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
307 |
+
pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")
|
308 |
|
309 |
total_pesq_score += pesq_score
|
310 |
total_loss += loss.item()
|
examples/mpnet/step_2_train_model.py
CHANGED
@@ -26,9 +26,10 @@ from tqdm import tqdm
|
|
26 |
|
27 |
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
|
28 |
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
|
29 |
-
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel
|
30 |
-
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses
|
31 |
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
|
|
|
32 |
|
33 |
|
34 |
def get_args():
|
@@ -251,7 +252,7 @@ def main():
|
|
251 |
mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
252 |
|
253 |
audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
|
254 |
-
|
255 |
|
256 |
# Discriminator
|
257 |
optim_d.zero_grad()
|
@@ -259,11 +260,12 @@ def main():
|
|
259 |
metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
|
260 |
loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
|
261 |
|
262 |
-
if
|
263 |
-
|
264 |
-
else:
|
265 |
-
# print("pesq is None!")
|
266 |
loss_disc_g = 0
|
|
|
|
|
|
|
267 |
|
268 |
loss_disc_all = loss_disc_r + loss_disc_g
|
269 |
loss_disc_all.backward()
|
@@ -334,11 +336,17 @@ def main():
|
|
334 |
audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
335 |
mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
336 |
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
total_mag_err += F.mse_loss(clean_mag, mag_g).item()
|
343 |
val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
|
344 |
total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
|
|
|
26 |
|
27 |
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
|
28 |
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
|
29 |
+
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel
|
30 |
+
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses
|
31 |
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
|
32 |
+
from toolbox.torchaudio.models.mpnet.metrics import run_batch_pesq, run_pesq_score
|
33 |
|
34 |
|
35 |
def get_args():
|
|
|
252 |
mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
253 |
|
254 |
audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
|
255 |
+
pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb")
|
256 |
|
257 |
# Discriminator
|
258 |
optim_d.zero_grad()
|
|
|
260 |
metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
|
261 |
loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
|
262 |
|
263 |
+
if -1 in pesq_score_list:
|
264 |
+
# print("-1 in batch_pesq_score!")
|
|
|
|
|
265 |
loss_disc_g = 0
|
266 |
+
else:
|
267 |
+
pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
|
268 |
+
loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
|
269 |
|
270 |
loss_disc_all = loss_disc_r + loss_disc_g
|
271 |
loss_disc_all.backward()
|
|
|
336 |
audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
337 |
mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
338 |
|
339 |
+
clean_audio_list = torch.split(clean_audio, 1, dim=0)
|
340 |
+
enhanced_audio_list = torch.split(audio_g, 1, dim=0)
|
341 |
+
clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
|
342 |
+
enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]
|
343 |
+
pesq_score = run_pesq_score(
|
344 |
+
clean_audio_list,
|
345 |
+
enhanced_audio_list,
|
346 |
+
sample_rate = config.sample_rate,
|
347 |
+
mode = "nb",
|
348 |
+
)
|
349 |
+
total_pesq_score += pesq_score
|
350 |
total_mag_err += F.mse_loss(clean_mag, mag_g).item()
|
351 |
val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
|
352 |
total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
|