fclong's picture
Upload 396 files
8ebda9e
# coding=utf-8
from collections import Counter
import torch
from torch import nn
# import seqeval
from .utils_ner import get_entities
class metrics_mlm_acc(nn.Module):
    """Token-level accuracy restricted to the masked positions of an MLM batch.

    Only positions whose ``masked_lm_metric`` value is > 0 are scored; every
    other position is ignored in both the numerator and the denominator.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, masked_lm_metric):
        """Return the fraction of masked positions predicted correctly.

        :param logits: prediction scores whose LAST axis is the vocabulary;
            the predicted token is the argmax over that axis. Presumably
            (batch, seq_len, vocab) -- any leading shape works as long as it
            matches ``labels``.
        :param labels: gold token ids with the same shape as ``logits``
            minus its last axis.
        :param masked_lm_metric: per-position mask with the same number of
            elements as ``labels``; positions with a value > 0 are scored.
        :return: 0-d float tensor in [0, 1]; 0.0 when no position is masked.
        """
        # Flatten everything so any leading shape (including 1-D, no batch
        # axis) is accepted; reshape, unlike view, also copes with
        # non-contiguous inputs.
        y_pred = torch.argmax(logits, dim=-1).reshape(-1)
        y_true = labels.reshape(-1)
        # The same "> 0" criterion selects positions for both the numerator and
        # the denominator, so mask values other than 0/1 cannot skew the ratio.
        scored = masked_lm_metric.reshape(-1) > 0
        n_correct = (torch.eq(y_pred, y_true) & scored).sum()
        # clamp avoids 0/0 = nan when nothing is masked (n_correct is 0 then,
        # so the result is 0.0).
        return n_correct.float() / scored.sum().clamp(min=1)
class EntityScore(object):
    """Accumulate entity-level precision / recall / F1 across many updates.

    Entities are matched by equality. The first element of each entity
    (``entity[0]``) is used as its type for the per-class breakdown.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every gold, predicted and matched entity seen so far."""
        self.origins, self.founds, self.rights = [], [], []

    def compute(self, origin, found, right):
        """Turn raw counts into ``(recall, precision, f1)``.

        :param origin: number of gold entities.
        :param found: number of predicted entities.
        :param right: number of predicted entities that are also gold.
        :return: tuple ``(recall, precision, f1)``; each is 0 when undefined.
        """
        if origin == 0:
            recall = 0
        else:
            recall = right / origin
        if found == 0:
            precision = 0
        else:
            precision = right / found
        if recall + precision == 0:
            f1 = 0.
        else:
            f1 = (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        """Score everything accumulated so far.

        :return: ``(overall, class_info)``. ``overall`` holds unrounded
            'acc' (precision), 'recall' and 'f1'. ``class_info`` maps each
            entity type present in the gold entities to the same three
            metrics rounded to 4 decimals.
        """
        def count_types(entities):
            # Tally how many entities of each type (entity[0]) a list holds.
            return Counter(entity[0] for entity in entities)

        gold_by_type = count_types(self.origins)
        pred_by_type = count_types(self.founds)
        hit_by_type = count_types(self.rights)

        class_info = {}
        # Only gold types are reported; Counter yields 0 for absent types.
        for entity_type in gold_by_type:
            r, p, f = self.compute(gold_by_type[entity_type],
                                   pred_by_type[entity_type],
                                   hit_by_type[entity_type])
            class_info[entity_type] = {"acc": round(p, 4), 'recall': round(r, 4), 'f1': round(f, 4)}

        r, p, f = self.compute(len(self.origins), len(self.founds), len(self.rights))
        overall = {'acc': p, 'recall': r, 'f1': f}
        return overall, class_info

    def update(self, true_subject, pred_subject):
        """Add one batch of gold and predicted entities.

        :param true_subject: gold entities.
        :param pred_subject: predicted entities; those also present in
            ``true_subject`` are recorded as correct.
        """
        self.origins.extend(true_subject)
        self.founds.extend(pred_subject)
        hits = [entity for entity in pred_subject if entity in true_subject]
        self.rights.extend(hits)
class SeqEntityScore(object):
    """Entity-level precision / recall / F1 computed from tag sequences.

    Each :meth:`update` decodes gold and predicted tag sequences into entities
    with ``get_entities`` (passing ``id2label``, ``markup`` and
    ``middle_prefix``) and accumulates them. :meth:`result` reports the
    overall and per-type scores, using ``entity[0]`` as the entity type.
    """

    def __init__(self, id2label, markup='bios', middle_prefix='I-'):
        self.id2label = id2label
        self.markup = markup
        self.middle_prefix = middle_prefix
        self.reset()

    def reset(self):
        """Forget every gold, predicted and matched entity seen so far."""
        self.origins, self.founds, self.rights = [], [], []

    def compute(self, origin, found, right):
        """Turn raw counts into ``(recall, precision, f1)``.

        :param origin: number of gold entities.
        :param found: number of predicted entities.
        :param right: number of predicted entities that are also gold.
        :return: tuple ``(recall, precision, f1)``; each is 0 when undefined.
        """
        if origin == 0:
            recall = 0
        else:
            recall = right / origin
        if found == 0:
            precision = 0
        else:
            precision = right / found
        if recall + precision == 0:
            f1 = 0.
        else:
            f1 = (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        """Score everything accumulated so far.

        :return: ``(overall, class_info)``. ``overall`` holds unrounded
            'acc' (precision), 'recall' and 'f1'. ``class_info`` maps each
            entity type present in the gold entities to the same three
            metrics rounded to 4 decimals.
        """
        def count_types(entities):
            # Tally how many entities of each type (entity[0]) a list holds.
            return Counter(entity[0] for entity in entities)

        gold_by_type = count_types(self.origins)
        pred_by_type = count_types(self.founds)
        hit_by_type = count_types(self.rights)

        class_info = {}
        # Only gold types are reported; Counter yields 0 for absent types.
        for entity_type in gold_by_type:
            r, p, f = self.compute(gold_by_type[entity_type],
                                   pred_by_type[entity_type],
                                   hit_by_type[entity_type])
            class_info[entity_type] = {"acc": round(p, 4), 'recall': round(r, 4), 'f1': round(f, 4)}

        r, p, f = self.compute(len(self.origins), len(self.founds), len(self.rights))
        overall = {'acc': p, 'recall': r, 'f1': f}
        return overall, class_info

    def update(self, label_paths, pred_paths):
        """Decode one batch of gold and predicted tag sequences and accumulate them.

        :param label_paths: list of gold tag sequences, one per sentence.
        :param pred_paths: list of predicted tag sequences aligned with
            ``label_paths``; ``zip`` silently drops any unpaired tail.

        Example::

            label_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        """
        def decode(path):
            # Turn one tag sequence into its entity list using this scorer's config.
            return get_entities(path, self.id2label, self.markup, self.middle_prefix)

        for gold_path, guess_path in zip(label_paths, pred_paths):
            gold = decode(gold_path)
            guess = decode(guess_path)
            self.origins.extend(gold)
            self.founds.extend(guess)
            hits = [entity for entity in guess if entity in gold]
            self.rights.extend(hits)