Spaces:
Runtime error
Runtime error
File size: 331 Bytes
ae29df4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
def l1(output, target):
    """Return the mean absolute difference between ``output`` and ``target``.

    The element-wise difference is taken first, then its absolute value,
    and finally ``torch.mean`` averages over every element of the result.

    Args:
        output: Predicted values.
        target: Reference values, subtracted from ``output``.
            NOTE(review): presumably the same shape as ``output``; nothing
            here checks it, so mismatched shapes follow the behavior of
            ``output - target`` -- confirm against callers.

    Returns:
        The result of ``torch.mean`` over the absolute differences.
    """
    residual = output - target
    magnitude = torch.abs(residual)
    return torch.mean(magnitude)
def l1_wav(output_dict, target_dict):
    """Return the L1 loss between the ``'segment'`` entries of two mappings.

    Args:
        output_dict: Mapping holding the prediction under ``'segment'``.
        target_dict: Mapping holding the reference under ``'segment'``.
            NOTE(review): the name suggests both entries are waveform
            tensors -- confirm against the data pipeline.

    Returns:
        Whatever ``l1`` returns for the two ``'segment'`` values.
    """
    predicted = output_dict['segment']
    reference = target_dict['segment']
    return l1(predicted, reference)
def get_loss_function(loss_type):
    """Return the loss callable selected by ``loss_type``.

    Args:
        loss_type: Name of the requested loss. Only ``"l1_wav"`` is
            supported.

    Returns:
        The ``l1_wav`` function when ``loss_type == "l1_wav"``.

    Raises:
        NotImplementedError: If ``loss_type`` is any other value. The
            message names the rejected value and the supported one.
    """
    if loss_type == "l1_wav":
        return l1_wav

    # Report the offending value so a configuration typo can be diagnosed
    # straight from the traceback instead of from a bare "Error!".
    raise NotImplementedError(
        f"Unsupported loss_type {loss_type!r}; supported values: 'l1_wav'."
    )
|