# Compact_Facts/utils/eval_ent_rel.py
# (uploaded by khulnasoft, "Upload 108 files", commit 4fb0bd1)
import ast
import logging
import sys
from collections import defaultdict
# Module-level logger; eval_file() and report() write their score tables through it.
logger = logging.getLogger(__name__)
class EvalCounts:
    """Running tallies for a single evaluation metric.

    Overall totals:
        correct_cnt: number of truth (gold) items seen.
        pred_cnt: number of predicted items seen.
        pred_correct_cnt: number of predictions that matched the truth.

    ``correct_types_cnt``, ``pred_types_cnt`` and ``pred_correct_types_cnt``
    hold the same three tallies broken down by type. They are
    ``defaultdict(int)`` so that a type never seen reads as 0.
    """

    def __init__(self):
        self.pred_correct_cnt = self.correct_cnt = self.pred_cnt = 0
        # Three independent per-type tables (a generator, so no object is shared).
        (self.pred_correct_types_cnt,
         self.correct_types_cnt,
         self.pred_types_cnt) = (defaultdict(int) for _ in range(3))
def _read_sents(lines, label2idx):
    """Group tagged result lines into per-sentence lists.

    Args:
        lines (iterable of str): raw lines of the results file; a blank line
            ends a sentence, every other line is ``<tag>\\t<field>...``.
        label2idx (dict): tag -> slot index. Lines whose tag is absent from it
            are skipped.

    Returns:
        list: one entry per non-empty sentence. Each entry is a list where
        ``entry[label2idx[tag]]`` holds the parsed values of all ``tag`` lines.
    """
    def new_sent():
        return [[] for _ in range(len(label2idx))]

    sents = []
    sent = new_sent()
    for line in lines:
        line = line.strip('\r\n')
        if line == "":
            # A blank line closes the current sentence.
            if any(sent):
                sents.append(sent)
            sent = new_sent()
            continue
        words = line.split('\t')
        tag = words[0]
        if tag not in label2idx:
            # Not needed by any requested metric (used to raise KeyError).
            continue
        slot = sent[label2idx[tag]]
        if tag in ('Sequence-Label-True', 'Sequence-Label-Pred'):
            # NOTE(review): assumed to share the space-separated, one-label-per-token
            # layout of the Ent-Label lines -- confirm against the file writer.
            # These lines were previously never parsed, so 'token' always scored 0.
            slot.extend(words[1].split(' '))
        elif tag in ('Ent-Label-True', 'Ent-Label-Pred', 'Rel-Label-True', 'Rel-Label-Pred'):
            # Space-separated labels are flattened into one sequence.
            slot.extend(words[1].split(' '))
        elif tag in ('Separate-Position-True', 'Separate-Position-Pred'):
            # Each line contributes one group of positions.
            slot.append(words[1].split(' '))
        elif tag == 'Ent-Span-Pred':
            # literal_eval instead of eval: the file must not be able to run code.
            slot.append(ast.literal_eval(words[1]))
        elif tag in ('Ent-True', 'Ent-Pred'):
            slot.append([words[1], ast.literal_eval(words[2])])
        elif tag in ('Rel-True', 'Rel-Pred'):
            slot.append([words[1], ast.literal_eval(words[2]), ast.literal_eval(words[3])])
    if any(sent):
        sents.append(sent)
    return sents


def eval_file(file_path, eval_metrics):
    """Evaluate a results file and log the scores for each requested metric.

    Args:
        file_path (str): path of the tab-separated results file.
        eval_metrics (list): metric names, each one of 'token', 'ent-label',
            'rel-label', 'separate-position', 'span', 'ent', 'rel', 'exact-rel'.

    Returns:
        list: the score returned by ``report`` for each distinct metric, in
        the order the metrics first appear in ``eval_metrics``.

    Raises:
        KeyError: if ``eval_metrics`` contains an unknown metric name.
    """
    metric2labels = {
        'token': ['Sequence-Label-True', 'Sequence-Label-Pred'],
        'ent-label': ['Ent-Label-True', 'Ent-Label-Pred'],
        'rel-label': ['Rel-Label-True', 'Rel-Label-Pred'],
        'separate-position': ['Separate-Position-True', 'Separate-Position-Pred'],
        'span': ['Ent-Span-Pred'],
        'ent': ['Ent-True', 'Ent-Pred'],
        'rel': ['Rel-True', 'Rel-Pred'],
        'exact-rel': ['Rel-True', 'Rel-Pred']
    }
    # evaluate() always reads Ent-True/Ent-Pred; the span, rel and exact-rel
    # scores are computed from them. Load them for every metric set.
    labels = {'Ent-True', 'Ent-Pred'}
    for metric in eval_metrics:
        labels.update(metric2labels[metric])
    label2idx = {label: idx for idx, label in enumerate(sorted(labels))}

    with open(file_path, 'r', encoding='utf-8') as fin:
        sents = _read_sents(fin, label2idx)

    counts = {metric: EvalCounts() for metric in eval_metrics}
    for sent in sents:
        evaluate(sent, counts, label2idx)

    results = []
    logger.info("-" * 22 + "START" + "-" * 23)
    for metric, count in counts.items():
        # Centre the metric name inside a 50-character rule.
        left_offset = (50 - len(metric)) // 2
        logger.info("-" * left_offset + metric + "-" * (50 - left_offset - len(metric)))
        results.append(report(count))
    logger.info("-" * 23 + "END" + "-" * 24)
    return results
def _count_labels(count, golds, preds):
    """Tally position-aligned gold/pred labels into ``count``.

    The string 'None' marks "no label". Besides per-label tallies, the pseudo
    type 'Arc' counts every labelled position. Its hit count ignores whether
    the two labels agree.
    """
    for gold, pred in zip(golds, preds):
        if gold != 'None':
            count.correct_cnt += 1
            count.correct_types_cnt['Arc'] += 1
            count.correct_types_cnt[gold] += 1
            # += 0 registers the label so report() lists it even with no hits.
            count.pred_correct_types_cnt[gold] += 0
        if pred != 'None':
            count.pred_cnt += 1
            count.pred_types_cnt['Arc'] += 1
            count.pred_types_cnt[pred] += 1
            count.pred_correct_types_cnt[pred] += 0
        if gold != 'None' and pred != 'None':
            count.pred_correct_types_cnt['Arc'] += 1
        if gold == pred and gold != 'None':
            count.pred_correct_cnt += 1
            count.pred_correct_types_cnt[gold] += 1


def _count_typed_sets(count, gold_by_type, pred_by_type):
    """Tally exact-match overlap of per-type item sets into ``count``.

    Args:
        count (EvalCounts): counters to update.
        gold_by_type (dict): type -> set of gold items.
        pred_by_type (dict): type -> set of predicted items.
    """
    for key in set(gold_by_type) | set(pred_by_type):
        gold = gold_by_type.get(key, set())
        pred = pred_by_type.get(key, set())
        hits = len(gold & pred)
        count.correct_cnt += len(gold)
        count.correct_types_cnt[key] += len(gold)
        count.pred_cnt += len(pred)
        count.pred_types_cnt[key] += len(pred)
        count.pred_correct_cnt += hits
        count.pred_correct_types_cnt[key] += hits


def _group_rels(rels, span2ent, with_ent_types):
    """Group ``[rel, span1, span2]`` rows by relation type.

    Rows whose argument spans are not both keys of ``span2ent`` are dropped.
    With ``with_ent_types`` the entity types of both arguments are part of the
    item, so they must also match.
    """
    grouped = defaultdict(set)
    for rel, span1, span2 in rels:
        if span1 not in span2ent or span2 not in span2ent:
            continue
        if with_ent_types:
            grouped[rel].add((span1, span2ent[span1], span2, span2ent[span2]))
        else:
            grouped[rel].add((span1, span2))
    return grouped


def evaluate(sent, counts, label2idx):
    """Accumulate the counters of one sentence into every requested metric.

    Args:
        sent (list): one parsed sentence; ``sent[label2idx[tag]]`` holds the
            rows read for ``tag``.
        counts (dict): metric name -> EvalCounts; only metrics present as keys
            are updated.
        label2idx (dict): tag -> index into ``sent``. A tag missing from it is
            treated as having no rows (previously a KeyError).
    """
    def rows(label):
        idx = label2idx.get(label)
        return [] if idx is None else sent[idx]

    # evaluate token ('O' means outside any entity)
    if 'token' in counts:
        count = counts['token']
        for gold, pred in zip(rows('Sequence-Label-True'), rows('Sequence-Label-Pred')):
            if gold != 'O':
                count.correct_cnt += 1
                count.correct_types_cnt[gold] += 1
                count.pred_correct_types_cnt[gold] += 0
            if pred != 'O':
                count.pred_cnt += 1
                count.pred_types_cnt[pred] += 1
                count.pred_correct_types_cnt[pred] += 0
            if gold == pred and gold != 'O':
                count.pred_correct_cnt += 1
                count.pred_correct_types_cnt[gold] += 1

    # evaluate ent label
    if 'ent-label' in counts:
        _count_labels(counts['ent-label'], rows('Ent-Label-True'), rows('Ent-Label-Pred'))

    # evaluate separate position: overlap of each aligned pair of position groups
    if 'separate-position' in counts:
        count = counts['separate-position']
        for gold, pred in zip(rows('Separate-Position-True'), rows('Separate-Position-Pred')):
            count.correct_cnt += len(gold)
            count.pred_cnt += len(pred)
            count.pred_correct_cnt += len(set(gold) & set(pred))

    # Entity lookups shared by span, ent, rel and exact-rel.
    correct_span2ent = {}
    correct_ent2idx = defaultdict(set)
    for ent, span in rows('Ent-True'):
        correct_span2ent[span] = ent
        correct_ent2idx[ent].add(span)
    pred_span2ent = {}
    pred_ent2idx = defaultdict(set)
    for ent, span in rows('Ent-Pred'):
        pred_span2ent[span] = ent
        pred_ent2idx[ent].add(span)

    # evaluate span: untyped span detection against the gold entity spans
    if 'span' in counts:
        count = counts['span']
        correct_span = set(correct_span2ent)
        pred_span = set(rows('Ent-Span-Pred'))
        count.correct_cnt += len(correct_span)
        count.pred_cnt += len(pred_span)
        count.pred_correct_cnt += len(correct_span & pred_span)

    # evaluate entity: span and type must both match
    if 'ent' in counts:
        # TODO: this part should change -- a non-continuous entity that is a
        # subset of another non-continuous entity should score.
        _count_typed_sets(counts['ent'], correct_ent2idx, pred_ent2idx)

    # evaluate rel label
    if 'rel-label' in counts:
        _count_labels(counts['rel-label'], rows('Rel-Label-True'), rows('Rel-Label-Pred'))

    # relation evaluation: type and both argument spans must match. This was
    # previously accepted as a metric but never computed, so it always scored 0.
    if 'rel' in counts:
        _count_typed_sets(counts['rel'],
                          _group_rels(rows('Rel-True'), correct_span2ent, False),
                          _group_rels(rows('Rel-Pred'), pred_span2ent, False))

    # exact relation evaluation: additionally requires both entity types to match
    if 'exact-rel' in counts:
        _count_typed_sets(counts['exact-rel'],
                          _group_rels(rows('Rel-True'), correct_span2ent, True),
                          _group_rels(rows('Rel-Pred'), pred_span2ent, True))
def _log_prf(p, r, f):
    """Log precision, recall and F1 as percentages."""
    logger.info("precision: %6.2f%%", 100 * p)
    logger.info("recall: %6.2f%%", 100 * r)
    logger.info("f1: %6.2f%%", 100 * f)


def report(counts):
    """Log overall and per-type scores for one metric.

    Args:
        counts (EvalCounts): accumulated counters of the metric.

    Returns:
        float: the overall F1 score (0 when precision and recall are both 0).
    """
    p, r, f = calculate_metrics(counts.pred_correct_cnt, counts.pred_cnt, counts.correct_cnt)
    logger.info("truth cnt: %s pred cnt: %s correct cnt: %s",
                counts.correct_cnt, counts.pred_cnt, counts.pred_correct_cnt)
    _log_prf(p, r, f)
    score = f

    # Every type seen in hits, truth or predictions, hit-table order first.
    # Iterating only pred_correct_types_cnt skipped types without a hit entry,
    # e.g. 'Arc' when no position was labelled on both sides.
    label_types = list(dict.fromkeys(list(counts.pred_correct_types_cnt)
                                     + list(counts.correct_types_cnt)
                                     + list(counts.pred_types_cnt)))
    for label_type in label_types:
        p, r, f = calculate_metrics(counts.pred_correct_types_cnt[label_type],
                                    counts.pred_types_cnt[label_type],
                                    counts.correct_types_cnt[label_type])
        logger.info("-" * 50)
        logger.info("type: %s", label_type)
        logger.info("truth cnt: %s pred cnt: %s correct cnt: %s",
                    counts.correct_types_cnt[label_type],
                    counts.pred_types_cnt[label_type],
                    counts.pred_correct_types_cnt[label_type])
        _log_prf(p, r, f)
    return score
def calculate_metrics(pred_correct_cnt, pred_cnt, correct_cnt):
    """Compute precision, recall and F1 from raw counts.

    Args:
        pred_correct_cnt (int): predictions that match the truth (true positives).
        pred_cnt (int): total number of predictions.
        correct_cnt (int): total number of truth items.

    Returns:
        tuple: ``(precision, recall, f1)``. Each value is 0 when its
        denominator is 0.
    """
    true_pos = pred_correct_cnt
    false_pos = pred_cnt - pred_correct_cnt
    false_neg = correct_cnt - pred_correct_cnt

    predicted = true_pos + false_pos
    precision = true_pos / predicted if predicted != 0 else 0

    actual = true_pos + false_neg
    recall = true_pos / actual if actual != 0 else 0

    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom != 0 else 0
    return precision, recall, f1
if __name__ == '__main__':
    # Usage: python eval_ent_rel.py RESULTS_FILE [METRIC ...]
    # eval_file() requires a metric list; the previous call omitted it and always
    # failed with TypeError. Default to the entity and exact-relation metrics.
    if len(sys.argv) < 2:
        sys.exit("usage: {} RESULTS_FILE [METRIC ...]".format(sys.argv[0]))
    # The scores are emitted via logger.info, which is hidden without a handler.
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    eval_file(sys.argv[1], sys.argv[2:] or ['ent', 'exact-rel'])