File size: 1,326 Bytes
6a4fac9 6be1f2a 29d0f05 95d4295 29d0f05 95d4295 29d0f05 6a4fac9 95d4295 29d0f05 95d4295 29d0f05 6be1f2a 29d0f05 6be1f2a 29d0f05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from Levenshtein import ratio
def _levenshtein(a, b):
    """Return the Levenshtein edit distance between strings *a* and *b*.

    Insertions, deletions and substitutions each cost 1. Uses the classic
    two-row dynamic-programming table, so memory is O(min(len(a), len(b))).
    """
    # Iterate over the longer string so the DP rows are as short as possible.
    if len(a) < len(b):
        a, b = b, a
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        current = [i]
        for j, cb in enumerate(b, 1):
            current.append(min(
                previous[j] + 1,               # delete ca
                current[j - 1] + 1,            # insert cb
                previous[j - 1] + (ca != cb),  # substitute (free if equal)
            ))
        previous = current
    return previous[-1]


def _normalized_levenshtein(a, b):
    """Return the edit distance between *a* and *b* divided by the longer length.

    The result lies in [0, 1]; two empty strings are treated as identical (0.0).
    """
    longest = max(len(a), len(b))
    if longest == 0:
        return 0.0
    return _levenshtein(a, b) / longest


def compute_score(predictions, ground_truths, *, theta=0.5):
    """Compute the ANLS (Average Normalized Levenshtein Similarity) score.

    For each question, the prediction is compared case-insensitively against
    every accepted answer. The normalized Levenshtein distance ``nl`` is the
    edit distance divided by the longer string's length. An answer contributes
    ``1 - nl`` when ``nl < theta`` and 0 otherwise. The question's score is the
    best over its answers, and the result is the mean over scored questions.

    Args:
        predictions: Mapping of question id -> predicted answer string.
        ground_truths: Mapping of question id -> iterable of accepted answer
            strings.
        theta: Distance threshold at or above which a match scores 0.

    Returns:
        The mean per-question score as a float in [0, 1]. Predictions whose id
        is absent from ``ground_truths`` are ignored. Returns 0.0 when no
        prediction could be scored.
    """
    anls_score = 0.0
    total = 0
    for qid, prediction in predictions.items():
        if qid not in ground_truths:
            continue
        # Count each question once, regardless of how many answers it accepts.
        total += 1
        pred = prediction.lower()
        best = 0.0
        for answer in ground_truths[qid]:
            nl = _normalized_levenshtein(pred, answer.lower())
            if nl < theta:
                best = max(best, 1 - nl)
        anls_score += best
    # Guard against division by zero when nothing overlapped.
    return anls_score / total if total else 0.0
if __name__ == "__main__":
    # Small worked example: three questions, each with one or more accepted answers.
    sample_predictions = [
        {'question_id': '10285', 'prediction_text': 'Denver R.'},
        {'question_id': '18601', 'prediction_text': '12'},
        {'question_id': '16734', 'prediction_text': 'dear'},
    ]
    sample_references = [
        {"answers": ["Denver Broncos", "Denver R. Broncos"], 'question_id': '10285'},
        {'answers': ['12/15/88'], 'question_id': '18601'},
        {'answers': ['Dear Dr. Lobo', 'Dr. Lobo'], 'question_id': '16734'},
    ]

    # Index both sides by question id so compute_score can pair them up.
    answers_by_qid = {ref['question_id']: ref['answers'] for ref in sample_references}
    prediction_by_qid = {
        pred['question_id']: pred['prediction_text'] for pred in sample_predictions
    }

    score = compute_score(predictions=prediction_by_qid, ground_truths=answers_by_qid)
    print(score)
|