# sociolome/scripts/archive/frame_similarity.py
# Gosse Minnema
# Initial commit
# 05922fb
# (file-viewer residue: raw / history / blame, 5.14 kB)
from argparse import ArgumentParser
from collections import defaultdict
from torch import nn
from copy import deepcopy
import torch
import os
import json
from sftp import SpanPredictor
import nltk
def shift_grid_cos_sim(mat: torch.Tensor):
    """Return the pairwise cosine similarity of the rows of ``mat``, mapped to [0, 1].

    ``mat`` must be two-dimensional (one vector per row). Entry ``[i, j]`` of
    the result compares row ``i`` with row ``j``. The raw cosine, which lies in
    [-1, 1], is shifted by one and halved, so 1 means "same direction" and 0
    means "opposite direction".
    """
    n_rows = mat.shape[0]
    # Two broadcast views of the same matrix: at grid position [i, j] the
    # first holds row j and the second holds row i.
    along_cols = mat[None, :, :].expand(n_rows, -1, -1)
    along_rows = mat[:, None, :].expand(-1, n_rows, -1)
    cosine = nn.CosineSimilarity(dim=2)(along_cols, along_rows)
    return (cosine + 1) / 2
def all_frames():
    """Return every frame of the FrameNet 1.7 corpus shipped with NLTK.

    ``nltk.download('framenet_v17')`` is invoked on every call, before the
    corpus reader is touched.
    """
    nltk.download('framenet_v17')
    return nltk.corpus.framenet.frames()
def extract_relations(fr):
    """List the frames that ``fr`` is related to, with the role each one plays.

    Every entry of ``fr.frameRelations`` is inspected under the keys
    ``'subFrameName'`` and ``'superFrameName'``. A name equal to ``fr.name``
    is skipped; any other name is reported as ``(name, role)`` where ``role``
    is the key without its ``'Name'`` suffix (``'subFrame'`` or
    ``'superFrame'``). Repeated names are kept, in encounter order.
    """
    own_name = fr.name
    related = []
    for relation in fr.frameRelations:
        for role_key in ('subFrameName', 'superFrameName'):
            other_name = relation[role_key]
            if other_name != own_name:
                # Strip the trailing "Name" to obtain the role label.
                related.append((other_name, role_key[:-4]))
    return related
def run():
    """Export frame-to-frame similarity tables derived from a span predictor.

    Command-line arguments:
        ARCHIVE_PATH: model archive handed to ``SpanPredictor.from_path``
            (loaded with ``cuda_device=-1``).
        DESTINATION: output directory; results are written to
            ``<DESTINATION>/mlp`` and ``<DESTINATION>/emb``.
        KAIROS: JSON file mapping each event name to a dict that has a
            ``'framenet'`` list (items carrying a ``'label'``) and a
            ``'description'``.
        --topk: number of most-similar frames listed per frame in the TSV
            files (default 10).

    Two similarity grids are computed with ``shift_grid_cos_sim``: one over
    the span-typing label embedding weights and one over the weights of the
    last span-typing MLP layer. For each grid the following files are
    written into its folder:

        raw.json          full ranking, ontology relations, URL and definition
                          for every FrameNet frame
        similarity.tsv    frame, definition, URL and its top-k similar frames
        ontology.tsv      frame, definition, URL and its related frames
        kairos_sheet.tsv  for each gold frame of each KAIROS event, a GOLD
                          row followed by its top-k similar frames
    """
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    # Close the input file deterministically instead of leaking the handle.
    with open(args.kairos, encoding='utf-8') as fp:
        kairos_gold_mapping = json.load(fp)

    span_typing = predictor._model._span_typing
    label_emb = span_typing.label_emb.weight.clone().detach()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')
    emb_sim = shift_grid_cos_sim(label_emb)

    # Fetch the frame inventory once; it is reused for the definition lookup
    # and by both rank_frame() calls, so the corpus download check is not
    # repeated.
    frames = list(all_frames())
    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in frames}

    last_mlp = span_typing.MLPs[-1].weight.detach().clone()
    mlp_sim = shift_grid_cos_sim(last_mlp)

    def rank_frame(sim):
        """Map every FrameNet frame to its metadata and ranked neighbours.

        ``sim`` is a square similarity grid whose row/column indices are
        looked up in ``idx2label``. Labels that are not FrameNet frame names
        are ignored both as sources and as neighbours, and a frame is never
        listed as its own neighbour.
        """
        # Per row: column indices from most to least similar, and the scores
        # in that same order. Converted to plain Python numbers once so the
        # loops below do not handle zero-dimensional tensors.
        order = sim.argsort(dim=1, descending=True)
        ranked_indices = order.tolist()
        ranked_scores = sim.gather(1, order).tolist()

        mapping = {
            fr.name: {
                'similarity': [],
                'ontology': extract_relations(fr),
                'URL': fr.URL,
                'definition': fr.definition,
            }
            for fr in frames
        }
        for left_idx, (right_indices, match_scores) in enumerate(
                zip(ranked_indices, ranked_scores)):
            left_label = idx2label[left_idx]
            if left_label not in mapping:
                continue
            neighbours = mapping[left_label]['similarity']
            for right_idx, score in zip(right_indices, match_scores):
                if right_idx == left_idx:
                    continue
                right_label = idx2label[right_idx]
                if right_label in mapping:
                    neighbours.append((right_label, score))
        return mapping

    emb_map = rank_frame(emb_sim)
    mlp_map = rank_frame(mlp_sim)

    def dump(mapping, folder_path):
        """Write raw.json and the three TSV sheets for ``mapping``."""
        os.makedirs(folder_path, exist_ok=True)
        with open(os.path.join(folder_path, 'raw.json'), 'w', encoding='utf-8') as fp:
            json.dump(mapping, fp)

        # NOTE(review): definitions are written verbatim; if any contains a
        # tab or newline the TSV rows will not be well-formed — confirm.
        sim_lines, onto_lines = [], []
        for fr, values in mapping.items():
            header = [fr, values['definition'], values['URL']]
            onto_cells = [
                f'{rel_fr_name} ({rel_type})'
                for rel_fr_name, rel_type in values['ontology']
            ]
            onto_lines.append('\t'.join(header + onto_cells))
            # Frames without any ranked neighbour get no similarity row.
            if values['similarity']:
                sim_cells = [
                    f'{sim_fr_name} ({score:.3f})'
                    for sim_fr_name, score in values['similarity'][:args.topk]
                ]
                sim_lines.append('\t'.join(header + sim_cells))

        with open(os.path.join(folder_path, 'similarity.tsv'), 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(sim_lines))
        with open(os.path.join(folder_path, 'ontology.tsv'), 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(onto_lines))

        kairos_rows = []
        for kairos_event, kairos_content in kairos_gold_mapping.items():
            description = str(kairos_content['description'])
            for gold_entry in kairos_content['framenet']:
                gold_fr = gold_entry['label']
                # Gold labels that are not FrameNet frame names are skipped.
                if gold_fr not in fr2definition:
                    continue
                gold_url, gold_definition = fr2definition[gold_fr]
                kairos_rows.append([
                    'GOLD',
                    gold_fr,
                    kairos_event,
                    gold_url,
                    gold_definition,
                    description,
                    '1.00',
                ])
                for ass_fr, sim_score in mapping[gold_fr]['similarity'][:args.topk]:
                    ass_url, ass_definition = fr2definition[ass_fr]
                    kairos_rows.append([
                        '',
                        ass_fr,
                        kairos_event,
                        ass_url,
                        ass_definition,
                        description,
                        f'{sim_score:.2f}',
                    ])
        with open(os.path.join(folder_path, 'kairos_sheet.tsv'), 'w', encoding='utf-8') as fp:
            fp.write('\n'.join('\t'.join(row) for row in kairos_rows))

    dump(mlp_map, os.path.join(args.dst, 'mlp'))
    dump(emb_map, os.path.join(args.dst, 'emb'))
# Script entry point: run the export only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    run()