HoneyTian committed · Commit 3834772 · 1 Parent(s): b78e61f
examples/nx_denoise/step_2_train_model.py CHANGED
@@ -288,8 +288,9 @@ def main():
             metric_g = discriminator.forward(clean_audios, audio_g.detach())
             loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
 
-            # loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_metric * 0.05 + loss_time * 0.2
-            loss_gen_all = loss_mag * 0.1 + loss_pha * 0.1 + loss_com * 0.1 + loss_metric * 0.9 + loss_time * 0.9
+            loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_metric * 0.05 + loss_time * 0.2
+            # loss_gen_all = loss_mag * 0.1 + loss_pha * 0.1 + loss_com * 0.1 + loss_metric * 0.9 + loss_time * 0.9
+            # 2.02
 
             loss_gen_all.backward()
             optim_g.step()
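This hunk swaps the generator objective back to the spectrum-weighted mix: the magnitude term dominates again, while the metric-discriminator and time-domain terms are scaled down. For reference, a minimal standalone sketch of that weighted sum, using stand-in scalar losses and dummy discriminator scores (the term names and weights come from the hunk above; all values here are placeholders, not the repo's real loss computation):

import torch
import torch.nn.functional as F

# Stand-in values; in the real training step these come from the
# generator output: magnitude, phase, complex-spectrum, and
# time-domain losses.
loss_mag = torch.tensor(1.0)
loss_pha = torch.tensor(1.0)
loss_com = torch.tensor(1.0)
loss_time = torch.tensor(1.0)

# Dummy metric-discriminator scores, regressed toward 1.0
# ("indistinguishable from clean"), as in the hunk above.
metric_g = torch.rand(4)
one_labels = torch.ones(4)
loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

# Weighting restored by this commit: spectral terms dominate,
# the adversarial metric term contributes only a small share.
loss_gen_all = (
    loss_mag * 0.9
    + loss_pha * 0.3
    + loss_com * 0.1
    + loss_metric * 0.05
    + loss_time * 0.2
)
print(float(loss_gen_all))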
toolbox/torchaudio/models/nx_denoise/transformers/transformers.py CHANGED
@@ -456,19 +456,19 @@ def main2():
 def main():
 
     encoder = TSTransformerEncoder(
-        input_size=16,
-        hidden_size=64,
-        attention_heads=4,
-        num_blocks=4,
+        input_size=8,
+        hidden_size=16,
+        attention_heads=2,
+        num_blocks=2,
         dropout_rate=0.1,
     )
     # print(encoder)
 
-    x = torch.ones([4, 16, 200, 32])
+    x = torch.ones([4, 8, 200, 8])
     y = encoder.forward(xs=x)
     print(y.shape)
 
-    x = torch.ones([4, 16, 200, 32])
+    x = torch.ones([4, 8, 200, 8])
     y = encoder.forward_chunk_by_chunk(xs=x)
     print(y.shape)
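This hunk shrinks the smoke test to a lighter configuration; note that the second dimension of the test tensor tracks input_size (16 before, 8 now). Below is a standalone sketch of the updated test. The import path is an assumption inferred from the file location above; the constructor arguments and call signatures are exactly those shown in the hunk:

import torch

# Assumed import path, mirroring the file's location in the repo.
from toolbox.torchaudio.models.nx_denoise.transformers.transformers import TSTransformerEncoder

encoder = TSTransformerEncoder(
    input_size=8,
    hidden_size=16,
    attention_heads=2,
    num_blocks=2,
    dropout_rate=0.1,
)

# Test input; dim 1 matches input_size, dim 2 is the time axis.
x = torch.ones([4, 8, 200, 8])

# One-shot (offline) forward pass.
y = encoder.forward(xs=x)
print(y.shape)

# Chunked (streaming-style) forward pass over the same input.
y = encoder.forward_chunk_by_chunk(xs=x)
print(y.shape)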