File size: 3,433 Bytes
50f0fbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# coding=utf-8
from collections import Counter
import torch
from torch import nn
# import seqeval

from .utils_ner import get_entities


class metrics_mlm_acc(nn.Module):
    """Accuracy of masked-language-model predictions over masked positions."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, masked_lm_metric):
        """Return the fraction of masked positions that were predicted correctly.

        :param logits: tensor whose last dimension is arg-maxed to obtain the
            predicted id of every position.
        :param labels: tensor of target ids; must hold as many elements as
            ``logits`` has positions (everything is flattened before comparing).
        :param masked_lm_metric: tensor with the same number of elements as
            ``labels``. Entries greater than 0 are counted as masked positions,
            and each entry multiplies the correctness flag of its position.
        :return: 0-dim float tensor ``sum(mask * correct) / #masked``; ``0`` when
            no entry of ``masked_lm_metric`` is greater than 0.
        """
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        y_pred = torch.argmax(logits, dim=-1).reshape(-1)
        y_true = labels.reshape(-1)
        mask = masked_lm_metric.reshape(-1)

        # Count the masked positions with one tensor op instead of a Python
        # loop over every element; this also works for masks of any rank.
        mask_label_size = (mask > 0).sum()

        corr = torch.eq(y_pred, y_true)
        corr = torch.multiply(mask, corr)

        # Clamp the denominator so a batch without masked positions yields 0
        # rather than NaN (0 / 0).
        acc = torch.sum(corr.float()) / mask_label_size.clamp(min=1)
        return acc


class SeqEntityScore(object):
    """Running entity-level precision / recall / F1 for sequence labelling.

    Three lists are accumulated across calls to :meth:`update`:
    ``origins`` (entities extracted from the label sequences), ``founds``
    (entities extracted from the predicted sequences) and ``rights``
    (predicted entities that also appear among the label entities of the
    same sequence). :meth:`result` turns them into scores.
    """

    def __init__(self, id2label, markup='bios', middle_prefix='I-'):
        # These three values are forwarded unchanged to ``get_entities``.
        self.id2label = id2label
        self.markup = markup
        self.middle_prefix = middle_prefix
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.origins = []
        self.founds = []
        self.rights = []

    def compute(self, origin, found, right):
        """Turn raw counts into ``(recall, precision, f1)``.

        :param origin: number of label entities.
        :param found: number of predicted entities.
        :param right: number of predicted entities counted as correct.
        :return: tuple ``(recall, precision, f1)``; a ratio whose denominator
            is zero is reported as 0.
        """
        recall = (right / origin) if origin != 0 else 0
        precision = (right / found) if found != 0 else 0
        if recall + precision == 0:
            f1 = 0.
        else:
            f1 = (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        """Summarise the accumulated entities.

        :return: tuple ``(overall, per_type)``. ``overall`` is a dict with the
            unrounded ``'acc'`` (precision), ``'recall'`` and ``'f1'`` over all
            entities. ``per_type`` maps the first element of each label entity
            to the same three keys, rounded to 4 decimals; only keys that
            occur in ``origins`` are listed.
        """
        # Element 0 of every stored entity is the key used for grouping.
        gold_by_type = Counter(entity[0] for entity in self.origins)
        pred_by_type = Counter(entity[0] for entity in self.founds)
        hit_by_type = Counter(entity[0] for entity in self.rights)

        per_type = {}
        for label, gold_total in gold_by_type.items():
            rec, prec, fscore = self.compute(
                gold_total,
                pred_by_type.get(label, 0),
                hit_by_type.get(label, 0),
            )
            per_type[label] = {
                "acc": round(prec, 4),
                'recall': round(rec, 4),
                'f1': round(fscore, 4),
            }

        rec, prec, fscore = self.compute(
            len(self.origins), len(self.founds), len(self.rights)
        )
        overall = {'acc': prec, 'recall': rec, 'f1': fscore}
        return overall, per_type

    def update(self, label_paths, pred_paths):
        '''
        Add one batch of label / prediction sequences to the running lists.

        label_paths: [[],[],[],....]
        pred_paths: [[],[],[],.....]

        :param label_paths: one tag sequence per sample (reference labels).
        :param pred_paths: one tag sequence per sample (predictions), paired
            with ``label_paths`` position by position via ``zip``.
        :return: None
        Example:
            >>> label_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            >>> pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        '''
        for gold_path, guess_path in zip(label_paths, pred_paths):
            gold_entities = get_entities(gold_path, self.id2label, self.markup, self.middle_prefix)
            guess_entities = get_entities(guess_path, self.id2label, self.markup, self.middle_prefix)

            # A prediction is correct when the very same entity was extracted
            # from the label sequence of this sample.
            matched = [entity for entity in guess_entities if entity in gold_entities]

            self.origins.extend(gold_entities)
            self.founds.extend(guess_entities)
            self.rights.extend(matched)