Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -304,7 +304,7 @@ def main():
|
|
304 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
305 |
raise AssertionError("nan or inf in snr_loss")
|
306 |
# loss = irm_loss + 0.1 * snr_loss
|
307 |
-
loss = irm_loss + 0.05 * snr_loss
|
308 |
# loss = irm_loss
|
309 |
|
310 |
total_loss += loss.item()
|
@@ -347,7 +347,7 @@ def main():
|
|
347 |
raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
|
348 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
349 |
# loss = irm_loss + 0.1 * snr_loss
|
350 |
-
loss = irm_loss + 0.05 * snr_loss
|
351 |
# loss = irm_loss
|
352 |
|
353 |
total_loss += loss.item()
|
|
|
304 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
305 |
raise AssertionError("nan or inf in snr_loss")
|
306 |
# loss = irm_loss + 0.1 * snr_loss
|
307 |
+
loss = 10.0 * irm_loss + 0.05 * snr_loss
|
308 |
# loss = irm_loss
|
309 |
|
310 |
total_loss += loss.item()
|
|
|
347 |
raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
|
348 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
349 |
# loss = irm_loss + 0.1 * snr_loss
|
350 |
+
loss = 10.0 * irm_loss + 0.05 * snr_loss
|
351 |
# loss = irm_loss
|
352 |
|
353 |
total_loss += loss.item()
|