File size: 4,970 Bytes
8ebda9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# coding=utf-8
from collections import Counter
import torch
from torch import nn
# import seqeval
from .utils_ner import get_entities
class metrics_mlm_acc(nn.Module):
    """Token-level accuracy of arg-max predictions over selected positions.

    Presumably used for masked-LM accuracy (per the class name); the module
    has no parameters and only implements ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, masked_lm_metric):
        """Return the accuracy of ``argmax(logits)`` against ``labels``.

        Args:
            logits: score tensor; the prediction is the arg-max over its last
                dimension.
            labels: gold ids; after flattening it is compared element-wise
                with the flattened predictions (assumed to have one entry per
                prediction — TODO confirm with callers).
            masked_lm_metric: per-position mask/weight tensor with the same
                number of elements as ``labels``. Any rank and memory layout
                is accepted.

        Returns:
            A 0-d float tensor: the sum of ``masked_lm_metric`` over correctly
            predicted positions, divided by the number of positions where
            ``masked_lm_metric > 0``. Returns 0 when no position is selected
            (the previous implementation produced nan in that case).
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted.
        y_pred = torch.argmax(logits, dim=-1).reshape(-1)
        y_true = labels.reshape(-1)
        mask = masked_lm_metric.reshape(-1)

        # Vectorised count of selected positions. This replaces a Python loop
        # that needed one device sync per element and only handled 2-D masks.
        mask_label_size = (mask > 0).sum()

        corr = torch.eq(y_pred, y_true)
        # Correct predictions are weighted by the mask value (unchanged behavior).
        corr = torch.multiply(mask, corr)
        total = torch.sum(corr.float())

        if mask_label_size == 0:
            # Nothing to score; avoid 0/0 -> nan leaking into running averages.
            return torch.zeros_like(total)
        return total / mask_label_size
class EntityScore(object):
    """Accumulate gold / predicted / matched entities and report P, R, F1.

    Each entity's first element (``entity[0]``) is used as its type for the
    per-type breakdown. Note that the ``'acc'`` key in the reported dicts
    holds precision.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated entities."""
        # origins: gold entities; founds: predicted; rights: predicted & gold
        self.origins, self.founds, self.rights = [], [], []

    def compute(self, origin, found, right):
        """Return ``(recall, precision, f1)`` from gold/predicted/matched counts.

        A zero denominator yields 0 for the corresponding score.
        """
        recall = right / origin if origin != 0 else 0
        precision = right / found if found != 0 else 0
        denominator = precision + recall
        if denominator == 0:
            f1 = 0.
        else:
            f1 = 2 * precision * recall / denominator
        return recall, precision, f1

    def result(self):
        """Return ``(overall, per_type)`` score dicts.

        ``overall`` has unrounded ``'acc'`` (precision), ``'recall'`` and
        ``'f1'``. ``per_type`` maps every type seen among the gold entities
        to the same three scores rounded to 4 decimals.
        """
        def tally(entities):
            # Count entities by their type (first element).
            return Counter(entity[0] for entity in entities)

        gold_by_type = tally(self.origins)
        pred_by_type = tally(self.founds)
        hit_by_type = tally(self.rights)

        class_info = {}
        for label_type, n_gold in gold_by_type.items():
            r, p, f = self.compute(
                n_gold,
                pred_by_type.get(label_type, 0),
                hit_by_type.get(label_type, 0),
            )
            class_info[label_type] = {"acc": round(p, 4), 'recall': round(r, 4), 'f1': round(f, 4)}

        r, p, f = self.compute(len(self.origins), len(self.founds), len(self.rights))
        return {'acc': p, 'recall': r, 'f1': f}, class_info

    def update(self, true_subject, pred_subject):
        """Add one batch of gold and predicted entities.

        A predicted entity counts as matched when it compares equal to some
        gold entity of the same batch.
        """
        self.origins.extend(true_subject)
        self.founds.extend(pred_subject)
        matched = [entity for entity in pred_subject if entity in true_subject]
        self.rights.extend(matched)
class SeqEntityScore(object):
    """Entity-level P/R/F1 over tag sequences.

    Gold and predicted tag paths are converted to entities with
    ``get_entities`` (configured by ``id2label``, ``markup`` and
    ``middle_prefix``), then accumulated for scoring. Each entity's first
    element is used as its type in the per-type breakdown. The ``'acc'`` key
    in the reported dicts holds precision.
    """

    def __init__(self, id2label, markup='bios', middle_prefix='I-'):
        self.id2label = id2label
        self.markup = markup
        self.middle_prefix = middle_prefix
        self.reset()

    def reset(self):
        """Clear all accumulated entities."""
        # origins: gold entities; founds: predicted; rights: predicted & gold
        self.origins, self.founds, self.rights = [], [], []

    def compute(self, origin, found, right):
        """Return ``(recall, precision, f1)`` from gold/predicted/matched counts.

        A zero denominator yields 0 for the corresponding score.
        """
        recall = right / origin if origin != 0 else 0
        precision = right / found if found != 0 else 0
        denominator = precision + recall
        if denominator == 0:
            f1 = 0.
        else:
            f1 = 2 * precision * recall / denominator
        return recall, precision, f1

    def result(self):
        """Return ``(overall, per_type)`` score dicts.

        ``overall`` has unrounded ``'acc'`` (precision), ``'recall'`` and
        ``'f1'``. ``per_type`` maps every type seen among the gold entities
        to the same three scores rounded to 4 decimals.
        """
        def tally(entities):
            # Count entities by their type (first element).
            return Counter(entity[0] for entity in entities)

        gold_by_type = tally(self.origins)
        pred_by_type = tally(self.founds)
        hit_by_type = tally(self.rights)

        class_info = {}
        for label_type, n_gold in gold_by_type.items():
            r, p, f = self.compute(
                n_gold,
                pred_by_type.get(label_type, 0),
                hit_by_type.get(label_type, 0),
            )
            class_info[label_type] = {"acc": round(p, 4), 'recall': round(r, 4), 'f1': round(f, 4)}

        r, p, f = self.compute(len(self.origins), len(self.founds), len(self.rights))
        return {'acc': p, 'recall': r, 'f1': f}, class_info

    def update(self, label_paths, pred_paths):
        """Decode paired gold/predicted tag paths and accumulate their entities.

        :param label_paths: sequence of gold tag paths, e.g. ``[[...], [...]]``
        :param pred_paths: sequence of predicted tag paths, paired with
            ``label_paths`` position by position (extra items are ignored)
        :return: None

        Example input::

            label_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        """
        for gold_path, guess_path in zip(label_paths, pred_paths):
            gold = get_entities(gold_path, self.id2label, self.markup, self.middle_prefix)
            guess = get_entities(guess_path, self.id2label, self.markup, self.middle_prefix)
            self.origins.extend(gold)
            self.founds.extend(guess)
            # A prediction matches when it equals some gold entity of the same path.
            self.rights.extend([entity for entity in guess if entity in gold])
|