Spaces:
Sleeping
Sleeping
File size: 5,135 Bytes
05922fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from argparse import ArgumentParser
from collections import defaultdict
from torch import nn
from copy import deepcopy
import torch
import os
import json
from sftp import SpanPredictor
import nltk
def shift_grid_cos_sim(mat: torch.Tensor) -> torch.Tensor:
    """Return pairwise cosine similarities between the rows of ``mat``, shifted into [0, 1].

    Entry ``[i, j]`` is ``(cos(mat[i], mat[j]) + 1) / 2``: rows pointing the same
    way score 1, orthogonal rows 0.5 and opposite rows 0. All-zero rows score 0.5
    against everything else.

    :param mat: 2-D tensor of shape ``(n, d)``; each row is one vector.
    :return: symmetric ``(n, n)`` tensor of shifted similarities.
    """
    # Normalise each row once and take a single matrix product. Comparing
    # expanded (n, n, d) views elementwise materialises (n, n, d)-sized
    # intermediates, which is far too large for a label vocabulary of ~1k rows.
    unit = nn.functional.normalize(mat, p=2, dim=1)
    # Clamp guards against rounding pushing |cos| marginally past 1.
    cos = (unit @ unit.T).clamp(-1.0, 1.0)
    return (cos + 1) / 2
def all_frames():
    """Return all FrameNet 1.7 frames from NLTK's FrameNet corpus reader.

    Every call first asks ``nltk.download`` for the ``framenet_v17`` data
    package, then returns ``nltk.corpus.framenet.frames()``.
    """
    package = 'framenet_v17'
    nltk.download(package)
    return nltk.corpus.framenet.frames()
def extract_relations(fr):
    """List the frames directly related to ``fr`` through its frame relations.

    Each entry of ``fr.frameRelations`` is read as a mapping with
    ``'subFrameName'`` and ``'superFrameName'`` keys. For each side whose name
    is not ``fr`` itself, the name is reported with its role: ``'subFrame'``
    or ``'superFrame'`` (the key without its ``Name`` suffix).

    :param fr: frame object exposing ``name`` and ``frameRelations``.
    :return: list of ``(frame_name, role)`` tuples in first-seen order; each
        related frame appears at most once, with the role of its first occurrence.
    """
    ret = list()
    # Seeded with fr's own name so the frame never lists itself.
    added = {fr.name}
    for rel in fr.frameRelations:
        for key in ('subFrameName', 'superFrameName'):
            rel_fr_name = rel[key]
            if rel_fr_name in added:
                continue
            # Record the name. Previously the set was never updated, so a frame
            # linked through several relations was listed once per relation.
            added.add(rel_fr_name)
            ret.append((rel_fr_name, key[:-4]))
    return ret
def _rank_frames(sim, idx2label, frames):
    """Build a per-frame record of metadata, relations and similarity-ranked neighbours.

    :param sim: square similarity matrix whose row/column ``i`` is label id ``i``.
    :param idx2label: mapping from label id to label string.
    :param frames: frame objects; their names are the only labels kept on
        either side of a ranking.
    :return: ``{frame_name: {'similarity': [(other_frame, score), ...],
        'ontology': [...], 'URL': ..., 'definition': ...}}``. Neighbours are
        sorted by descending score and exclude the frame itself.
    """
    rank = sim.argsort(dim=1, descending=True)
    scores = sim.gather(1, rank)
    mapping = {
        fr.name: {
            'similarity': list(),
            'ontology': extract_relations(fr),
            'URL': fr.URL,
            'definition': fr.definition,
        }
        for fr in frames
    }
    # Convert to Python lists once; per-element tensor indexing inside the
    # n*n loop below is very slow.
    for left_idx, (right_indices, match_scores) in enumerate(zip(rank.tolist(), scores.tolist())):
        left_label = idx2label[left_idx]
        if left_label not in mapping:
            continue
        similar = mapping[left_label]['similarity']
        for right_idx, score in zip(right_indices, match_scores):
            right_label = idx2label[right_idx]
            if right_label not in mapping or right_idx == left_idx:
                continue
            similar.append((right_label, score))
    return mapping


def _write_tsv(path, lines):
    """Write ``lines`` to ``path`` as UTF-8 text joined by newlines (no trailing newline)."""
    with open(path, 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(lines))


def _dump(mapping, folder_path, fr2definition, kairos_gold_mapping, topk):
    """Write the reports for one frame mapping into ``folder_path`` (created if missing).

    Files written:
      * ``raw.json`` - the full ``mapping``.
      * ``similarity.tsv`` - for each frame with at least one neighbour: name,
        definition, URL, then up to ``topk`` ``"frame (score)"`` cells.
      * ``ontology.tsv`` - for every frame: name, definition, URL, then one
        ``"frame (role)"`` cell per relation.
      * ``kairos_sheet.tsv`` - for each KAIROS event and each of its gold
        FrameNet labels present in ``fr2definition``: a ``GOLD`` row followed
        by up to ``topk`` rows for that frame's nearest neighbours.

    ``kairos_gold_mapping`` is read as
    ``{event: {'framenet': [{'label': ...}, ...], 'description': ...}}``.
    """
    os.makedirs(folder_path, exist_ok=True)
    with open(os.path.join(folder_path, 'raw.json'), 'w', encoding='utf-8') as fp:
        json.dump(mapping, fp)

    sim_lines, onto_lines = list(), list()
    for fr, values in mapping.items():
        header = [fr, values['definition'], values['URL']]
        relation_cells = [f'{rel_fr_name} ({rel_type})' for rel_fr_name, rel_type in values['ontology']]
        onto_lines.append('\t'.join(header + relation_cells))
        # Frames without ranked neighbours (typically frames missing from the
        # model's label vocabulary) are left out of similarity.tsv.
        if values['similarity']:
            sim_cells = [f'{sim_fr_name} ({score:.3f})' for sim_fr_name, score in values['similarity'][:topk]]
            sim_lines.append('\t'.join(header + sim_cells))
    _write_tsv(os.path.join(folder_path, 'similarity.tsv'), sim_lines)
    _write_tsv(os.path.join(folder_path, 'ontology.tsv'), onto_lines)

    kairos_rows = list()
    for kairos_event, kairos_content in kairos_gold_mapping.items():
        for gold in kairos_content['framenet']:
            gold_fr = gold['label']
            if gold_fr not in fr2definition:
                continue
            description = str(kairos_content['description'])
            gold_url, gold_definition = fr2definition[gold_fr]
            kairos_rows.append(['GOLD', gold_fr, kairos_event, gold_url, gold_definition, description, '1.00'])
            for ass_fr, sim_score in mapping[gold_fr]['similarity'][:topk]:
                ass_url, ass_definition = fr2definition[ass_fr]
                kairos_rows.append(
                    ['', ass_fr, kairos_event, ass_url, ass_definition, description, f'{sim_score:.2f}']
                )
    _write_tsv(os.path.join(folder_path, 'kairos_sheet.tsv'), ['\t'.join(row) for row in kairos_rows])


def run():
    """Command-line entry point: compare frame labels as seen by a SpanPredictor model.

    Positional arguments: ARCHIVE_PATH (model archive loaded with
    ``cuda_device=-1``), DESTINATION (output directory) and KAIROS (JSON file
    mapping KAIROS events to gold FrameNet labels). ``--topk`` (default 10)
    caps how many neighbours appear per row in the TSV reports.

    Two similarity matrices are built with ``shift_grid_cos_sim``: one over the
    rows of the span-typing label embedding, and one over the rows of the last
    span-typing MLP layer's weight. Each is ranked per FrameNet frame, and the
    reports are written to ``DESTINATION/mlp`` and ``DESTINATION/emb``.
    """
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    with open(args.kairos, encoding='utf-8') as fp:
        kairos_gold_mapping = json.load(fp)

    span_typing = predictor._model._span_typing
    label_emb = span_typing.label_emb.weight.clone().detach()
    last_mlp = span_typing.MLPs[-1].weight.detach().clone()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

    # all_frames() calls nltk.download on every invocation. Fetch and
    # materialise the frames once instead of once per consumer.
    frames = list(all_frames())
    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in frames}

    emb_map = _rank_frames(shift_grid_cos_sim(label_emb), idx2label, frames)
    mlp_map = _rank_frames(shift_grid_cos_sim(last_mlp), idx2label, frames)

    _dump(mlp_map, os.path.join(args.dst, 'mlp'), fr2definition, kairos_gold_mapping, args.topk)
    _dump(emb_map, os.path.join(args.dst, 'emb'), fr2definition, kairos_gold_mapping, args.topk)
if __name__ == '__main__':
    # Run the command-line tool only when executed as a script, not on import.
    run()
|