# coding=utf-8
"""Evaluation metrics: masked-LM token accuracy and entity-level P/R/F1."""
from collections import Counter

import torch
from torch import nn

from .utils_ner import get_entities


class metrics_mlm_acc(nn.Module):
    """Token accuracy of a masked-LM head, restricted to masked positions."""

    def forward(self, logits, labels, masked_lm_metric):
        """Return the fraction of masked positions whose argmax equals the label.

        Args:
            logits: prediction scores; the class dimension is the last one
                (argmax is taken over ``dim=-1``).
            labels: gold token ids; must have one element per position of
                ``logits.shape[:-1]`` once flattened.
            masked_lm_metric: per-position mask with the same number of
                elements as ``labels``; entries ``> 0`` mark scored positions.

        Returns:
            A 0-d float tensor. It is 0 when no position is masked, rather
            than the NaN a 0/0 division would give.
        """
        # reshape (not view) so non-contiguous inputs are accepted.
        y_pred = torch.argmax(logits, dim=-1).reshape(-1)
        y_true = labels.reshape(-1)
        # Binarize once so numerator and denominator count the same positions
        # (multiplying by raw mask values disagrees with a "> 0" count for
        # non-0/1 masks). Vectorized: no per-element Python loop / host sync.
        mask = masked_lm_metric.reshape(-1) > 0
        n_masked = mask.sum()
        n_correct = (torch.eq(y_pred, y_true) & mask).sum()
        # clamp(min=1) turns the empty-mask case into 0 / 1 == 0.
        return n_correct.float() / n_masked.clamp(min=1)


class EntityScore(object):
    """Accumulates gold/predicted entities and reports precision, recall and F1.

    Entities are matched by ``==``. Element ``[0]`` of each entity is used as
    its type for the per-type breakdown in :meth:`result`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard every entity accumulated so far."""
        self.origins = []  # gold entities
        self.founds = []   # predicted entities
        self.rights = []   # predicted entities that also occur in the gold set

    def compute(self, origin, found, right):
        """Return ``(recall, precision, f1)`` from raw counts.

        Each ratio is 0 when its denominator is 0.
        """
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        """Return ``(overall, class_info)``.

        ``overall`` maps 'acc' (precision), 'recall' and 'f1' to unrounded
        micro-averaged values. ``class_info`` maps each entity type to the same
        keys, rounded to 4 places. Only types that occur in the gold entities
        get an entry.
        """
        origin_counter = Counter(x[0] for x in self.origins)
        found_counter = Counter(x[0] for x in self.founds)
        right_counter = Counter(x[0] for x in self.rights)
        class_info = {}
        for type_, origin in origin_counter.items():
            found = found_counter.get(type_, 0)
            right = right_counter.get(type_, 0)
            recall, precision, f1 = self.compute(origin, found, right)
            class_info[type_] = {"acc": round(precision, 4),
                                 'recall': round(recall, 4),
                                 'f1': round(f1, 4)}
        recall, precision, f1 = self.compute(len(self.origins), len(self.founds), len(self.rights))
        return {'acc': precision, 'recall': recall, 'f1': f1}, class_info

    def update(self, true_subject, pred_subject):
        """Add one example's gold and predicted entity lists.

        A predicted entity counts as correct when it also appears in
        ``true_subject``. List membership is used because entities may be
        unhashable (e.g. lists).
        """
        self.origins.extend(true_subject)
        self.founds.extend(pred_subject)
        self.rights.extend([pre_entity for pre_entity in pred_subject if pre_entity in true_subject])


class SeqEntityScore(EntityScore):
    """Entity-level P/R/F1 computed from tag sequences.

    Each sequence is decoded into entities with ``get_entities`` and then
    scored exactly like :class:`EntityScore`.
    """

    def __init__(self, id2label, markup='bios', middle_prefix='I-'):
        """
        Args:
            id2label: passed through to ``get_entities`` (presumably maps tag
                ids to tag strings — confirm against utils_ner).
            markup: tagging scheme name passed to ``get_entities``.
            middle_prefix: inside-tag prefix passed to ``get_entities``.
        """
        self.id2label = id2label
        self.markup = markup
        self.middle_prefix = middle_prefix
        super().__init__()  # initializes the accumulators via reset()

    def update(self, label_paths, pred_paths):
        """Decode and accumulate a batch of gold / predicted tag sequences.

        Args:
            label_paths: list of gold tag sequences, one per example.
            pred_paths: list of predicted tag sequences, aligned with
                ``label_paths``. Extra items in the longer list are ignored
                (``zip``).

        Example::

            label_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        """
        for label_path, pre_path in zip(label_paths, pred_paths):
            label_entities = get_entities(label_path, self.id2label, self.markup, self.middle_prefix)
            pre_entities = get_entities(pre_path, self.id2label, self.markup, self.middle_prefix)
            EntityScore.update(self, label_entities, pre_entities)