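"""
Rank FrameNet frames by similarity and relate them to KAIROS event types.

Given a trained SFTP span-predictor archive, two frame-to-frame similarity
matrices are computed: one from the span-typing label embeddings ("emb") and
one from the weights of the last span-typing MLP layer ("mlp"). For each, the
script writes raw.json, similarity.tsv, ontology.tsv, and kairos_sheet.tsv
under DESTINATION/emb and DESTINATION/mlp.

Usage (arguments as declared below; paths are placeholders):
    python <this_script> ARCHIVE_PATH DESTINATION KAIROS [--topk TOPK]

The KAIROS file is assumed (from how it is read below) to map event names to
objects with a 'framenet' list of {'label': ...} entries and a 'description'.
"""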
from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy
import json
import os

import nltk
import torch
from torch import nn

from sftp import SpanPredictor
def shift_grid_cos_sim(mat: torch.Tensor):
    """Pairwise cosine similarity between rows of `mat`, shifted from [-1, 1] to [0, 1]."""
    mat1 = mat.unsqueeze(0).expand(mat.shape[0], -1, -1)
    mat2 = mat.unsqueeze(1).expand(-1, mat.shape[0], -1)
    cos = nn.CosineSimilarity(2)
    sim = (cos(mat1, mat2) + 1) / 2
    return sim
def all_frames():
    # Make sure the FrameNet 1.7 corpus is available, then return all frames.
    nltk.download('framenet_v17')
    fn = nltk.corpus.framenet
    return fn.frames()
def extract_relations(fr):
    # Collect (related frame name, relation type) pairs for a frame, covering
    # sub-frame and super-frame relations and skipping the frame itself.
    ret = list()
    added = {fr.name}
    for rel in fr.frameRelations:
        for key in ['subFrameName', 'superFrameName']:
            rel_fr_name = rel[key]
            if rel_fr_name in added:
                continue
            ret.append((rel_fr_name, key[:-4]))
    return ret
def run():
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    with open(args.kairos) as fp:
        kairos_gold_mapping = json.load(fp)

    # Frame-to-frame similarity based on the span-typing label embeddings.
    label_emb = predictor._model._span_typing.label_emb.weight.clone().detach()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')
    emb_sim = shift_grid_cos_sim(label_emb)

    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in all_frames()}

    # Frame-to-frame similarity based on the weights of the last span-typing MLP layer.
    last_mlp = predictor._model._span_typing.MLPs[-1].weight.detach().clone()
    mlp_sim = shift_grid_cos_sim(last_mlp)
    def rank_frame(sim):
        # For every frame label, rank all other frame labels by descending similarity.
        rank = sim.argsort(1, True)
        scores = sim.gather(1, rank)
        mapping = {
            fr.name: {
                'similarity': list(),
                'ontology': extract_relations(fr),
                'URL': fr.URL,
                'definition': fr.definition
            } for fr in all_frames()
        }
        for left_idx, (right_indices, match_scores) in enumerate(zip(rank, scores)):
            left_label = idx2label[left_idx]
            if left_label not in mapping:
                continue
            for right_idx, s in zip(right_indices, match_scores):
                right_label = idx2label[int(right_idx)]
                if right_label not in mapping or right_idx == left_idx:
                    continue
                mapping[left_label]['similarity'].append((right_label, float(s)))
        return mapping

    emb_map = rank_frame(emb_sim)
    mlp_map = rank_frame(mlp_sim)
    def dump(mapping, folder_path):
        os.makedirs(folder_path, exist_ok=True)
        # Raw mapping as JSON, plus TSV sheets for similarity and ontology relations.
        with open(os.path.join(folder_path, 'raw.json'), 'w') as fp:
            json.dump(mapping, fp)
        sim_lines, onto_lines = list(), list()
        for fr, values in mapping.items():
            sim_line = [
                fr,
                values['definition'],
                values['URL'],
            ]
            onto_line = deepcopy(sim_line)
            for rel_fr_name, rel_type in values['ontology']:
                onto_line.append(f'{rel_fr_name} ({rel_type})')
            onto_lines.append('\t'.join(onto_line))
            if len(values['similarity']) > 0:
                for sim_fr_name, score in values['similarity'][:args.topk]:
                    sim_line.append(f'{sim_fr_name} ({score:.3f})')
                sim_lines.append('\t'.join(sim_line))
        with open(os.path.join(folder_path, 'similarity.tsv'), 'w') as fp:
            fp.write('\n'.join(sim_lines))
        with open(os.path.join(folder_path, 'ontology.tsv'), 'w') as fp:
            fp.write('\n'.join(onto_lines))

        # KAIROS sheet: each gold FrameNet frame followed by its top-k most similar frames.
        kairos_dump = list()
        for kairos_event, kairos_content in kairos_gold_mapping.items():
            for gold_fr in kairos_content['framenet']:
                gold_fr = gold_fr['label']
                if gold_fr not in fr2definition:
                    continue
                kairos_dump.append([
                    'GOLD',
                    gold_fr,
                    kairos_event,
                    fr2definition[gold_fr][0],
                    fr2definition[gold_fr][1],
                    str(kairos_content['description']),
                    '1.00'
                ])
                for ass_fr, sim_score in mapping[gold_fr]['similarity'][:args.topk]:
                    kairos_dump.append([
                        '',
                        ass_fr,
                        kairos_event,
                        fr2definition[ass_fr][0],
                        fr2definition[ass_fr][1],
                        str(kairos_content['description']),
                        f'{sim_score:.2f}'
                    ])
        kairos_lines = ['\t'.join(line) for line in kairos_dump]
        with open(os.path.join(folder_path, 'kairos_sheet.tsv'), 'w') as fp:
            fp.write('\n'.join(kairos_lines))
    dump(mlp_map, os.path.join(args.dst, 'mlp'))
    dump(emb_map, os.path.join(args.dst, 'emb'))


if __name__ == '__main__':
    run()