import torch


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy of argmax predictions over non-ignored positions.

    Args:
        pad_outputs (Tensor): Prediction scores of shape (B * Lmax, D).
            A (B, Lmax, D) tensor with the same number of elements is also
            accepted, because it is viewed as (B, Lmax, D) and the class
            dimension is taken from its last axis.
        pad_targets (LongTensor): Target label ids of shape (B, Lmax).
        ignore_label (int): Label id whose positions are excluded from
            both the numerator and the denominator.

    Returns:
        float: Fraction (0.0 - 1.0) of non-ignored positions where the
        argmax over D equals the target. Returns 0.0 when every position
        is ignored, instead of raising ``ZeroDivisionError``.
    """
    # Use the last axis as the class dimension so that both (B * Lmax, D)
    # and (B, Lmax, D) inputs are handled; for 2-D input size(-1) == size(1).
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(-1)
    ).argmax(-1)
    mask = pad_targets != ignore_label

    denominator = int(mask.sum())
    if denominator == 0:
        # Every position is ignored: there is nothing to score.
        return 0.0

    # Count positions that are both valid and correctly predicted.
    numerator = int(((pad_pred == pad_targets) & mask).sum())
    return numerator / denominator