HoneyTian commited on
Commit
7ec71f8
·
1 Parent(s): ce34f8c
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -304,7 +304,7 @@ def main():
304
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
305
  raise AssertionError("nan or inf in snr_loss")
306
  # loss = irm_loss + 0.1 * snr_loss
307
- loss = irm_loss + 0.05 * snr_loss
308
  # loss = irm_loss
309
 
310
  total_loss += loss.item()
@@ -347,7 +347,7 @@ def main():
347
  raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
348
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
349
  # loss = irm_loss + 0.1 * snr_loss
350
- loss = irm_loss + 0.05 * snr_loss
351
  # loss = irm_loss
352
 
353
  total_loss += loss.item()
 
304
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
305
  raise AssertionError("nan or inf in snr_loss")
306
  # loss = irm_loss + 0.1 * snr_loss
307
+ loss = 10.0 * irm_loss + 0.05 * snr_loss
308
  # loss = irm_loss
309
 
310
  total_loss += loss.item()
 
347
  raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
348
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
349
  # loss = irm_loss + 0.1 * snr_loss
350
+ loss = 10.0 * irm_loss + 0.05 * snr_loss
351
  # loss = irm_loss
352
 
353
  total_loss += loss.item()