|
|
|
from collections import Counter |
|
import torch |
|
from torch import nn |
|
|
|
|
|
from .utils_ner import get_entities |
|
|
|
|
|
class metrics_mlm_acc(nn.Module):
    """Accuracy of masked-language-model predictions over marked positions."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, masked_lm_metric):
        """Return the share of marked positions whose prediction equals the label.

        Args:
            logits: Score tensor; the predicted id of each position is the
                argmax over its last dimension.
            labels: Target ids, one per position of ``logits`` (same number
                of elements as ``logits`` without its last dimension).
            masked_lm_metric: Tensor with as many elements as ``labels``;
                entries greater than zero mark the positions that are scored.

        Returns:
            A 0-dim float tensor with the accuracy. When no position is
            marked the result is 0 rather than NaN.
        """
        # One boolean mask drives both numerator and denominator, so mask
        # values other than 0/1 cannot skew the ratio.
        selected = masked_lm_metric.reshape(-1) > 0

        y_pred = torch.argmax(logits, dim=-1).reshape(-1)
        y_true = labels.reshape(-1)

        hits = torch.eq(y_pred, y_true) & selected

        # Clamp the denominator to 1: with nothing marked, hits is 0 as well,
        # so the accuracy becomes 0 instead of dividing by zero.
        total = selected.sum().clamp(min=1)
        return hits.sum().float() / total
|
|
|
|
|
class EntityScore:
    """Accumulate gold and predicted entities and report precision/recall/F1.

    Entities are kept in three lists: ``origins`` (gold), ``founds``
    (predicted) and ``rights`` (predicted entities that also appear in the
    gold list of the same ``update`` call). The first item of every entity
    is used as its type when building the per-type report.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.origins = []
        self.founds = []
        self.rights = []

    def compute(self, origin, found, right):
        """Turn raw counts into ``(recall, precision, f1)``.

        A zero denominator yields 0 for the affected score instead of raising.
        """
        recall = right / origin if origin != 0 else 0
        precision = right / found if found != 0 else 0
        if recall + precision == 0:
            f1 = 0.
        else:
            f1 = (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def _type_report(self, gold_count, found_count, right_count):
        """Build the rounded score dict for a single entity type."""
        recall, precision, f1 = self.compute(gold_count, found_count, right_count)
        return {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}

    def result(self):
        """Return ``(overall_scores, per_type_scores)``.

        Per-type scores are rounded to four decimals and only cover types
        that occur among the gold entities; overall scores are not rounded.
        """
        gold_by_type = Counter(entity[0] for entity in self.origins)
        found_by_type = Counter(entity[0] for entity in self.founds)
        right_by_type = Counter(entity[0] for entity in self.rights)

        class_info = {}
        for type_, gold_count in gold_by_type.items():
            class_info[type_] = self._type_report(
                gold_count,
                found_by_type.get(type_, 0),
                right_by_type.get(type_, 0),
            )

        recall, precision, f1 = self.compute(
            len(self.origins), len(self.founds), len(self.rights)
        )
        overall = {'acc': precision, 'recall': recall, 'f1': f1}
        return overall, class_info

    def update(self, true_subject, pred_subject):
        """Record one sample's gold entities and predicted entities."""
        self.origins.extend(true_subject)
        self.founds.extend(pred_subject)
        # A prediction counts as right when it is found among this sample's
        # gold entities.
        matched = [candidate for candidate in pred_subject if candidate in true_subject]
        self.rights.extend(matched)
|
|
|
class SeqEntityScore:
    """Entity-level precision/recall/F1 computed from tag sequences.

    Tag sequences are decoded into entities with ``get_entities`` and
    accumulated in ``origins`` (gold), ``founds`` (predicted) and ``rights``
    (predicted entities also present in the gold entities of the same
    sequence). The first item of every entity is used as its type.
    """

    def __init__(self, id2label, markup='bios', middle_prefix='I-'):
        # All three settings are forwarded unchanged to get_entities.
        self.id2label = id2label
        self.markup = markup
        self.middle_prefix = middle_prefix
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.origins = []
        self.founds = []
        self.rights = []

    def compute(self, origin, found, right):
        """Turn raw counts into ``(recall, precision, f1)``.

        A zero denominator yields 0 for the affected score instead of raising.
        """
        recall = right / origin if origin != 0 else 0
        precision = right / found if found != 0 else 0
        if recall + precision == 0:
            f1 = 0.
        else:
            f1 = (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def _type_report(self, gold_count, found_count, right_count):
        """Build the rounded score dict for a single entity type."""
        recall, precision, f1 = self.compute(gold_count, found_count, right_count)
        return {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}

    def result(self):
        """Return ``(overall_scores, per_type_scores)``.

        Per-type scores are rounded to four decimals and only cover types
        that occur among the gold entities; overall scores are not rounded.
        """
        gold_by_type = Counter(entity[0] for entity in self.origins)
        found_by_type = Counter(entity[0] for entity in self.founds)
        right_by_type = Counter(entity[0] for entity in self.rights)

        class_info = {}
        for type_, gold_count in gold_by_type.items():
            class_info[type_] = self._type_report(
                gold_count,
                found_by_type.get(type_, 0),
                right_by_type.get(type_, 0),
            )

        recall, precision, f1 = self.compute(
            len(self.origins), len(self.founds), len(self.rights)
        )
        overall = {'acc': precision, 'recall': recall, 'f1': f1}
        return overall, class_info

    def update(self, label_paths, pred_paths):
        '''
        Decode paired gold/predicted tag sequences and accumulate their entities.

        label_paths: [[],[],[],....]
        pred_paths: [[],[],[],.....]

        :param label_paths: gold tag sequences, one per sample
        :param pred_paths: predicted tag sequences, paired with label_paths by position
        :return: None
        Example:
            >>> labels_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            >>> pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        '''
        decode_args = (self.id2label, self.markup, self.middle_prefix)
        for gold_path, guess_path in zip(label_paths, pred_paths):
            gold_entities = get_entities(gold_path, *decode_args)
            guess_entities = get_entities(guess_path, *decode_args)

            self.origins.extend(gold_entities)
            self.founds.extend(guess_entities)
            # A prediction counts as right when the same entity was decoded
            # from the gold sequence of this sample.
            matched = [candidate for candidate in guess_entities if candidate in gold_entities]
            self.rights.extend(matched)
|
|