File size: 5,135 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from argparse import ArgumentParser
from collections import defaultdict

from torch import nn
from copy import deepcopy
import torch
import os
import json

from sftp import SpanPredictor
import nltk


def shift_grid_cos_sim(mat: torch.Tensor) -> torch.Tensor:
    """All-pairs cosine similarity between rows of `mat`, shifted from [-1, 1] into [0, 1].

    :param mat: 2-D tensor of shape (n, d) — one embedding per row.
    :return: (n, n) tensor where entry (i, j) is (cos(row_i, row_j) + 1) / 2.
    """
    n = mat.shape[0]
    # Broadcast the row set against itself: (n, d) -> (n, n, d) along both grid axes.
    rows = mat.unsqueeze(0).expand(n, -1, -1)
    cols = mat.unsqueeze(1).expand(-1, n, -1)
    # Cosine similarity over the feature dim, then rescale into [0, 1].
    return (nn.CosineSimilarity(2)(rows, cols) + 1) / 2


def all_frames():
    """Return all frames from the FrameNet 1.7 corpus, downloading it first if absent.

    nltk.download is a no-op when the corpus is already present locally.
    """
    nltk.download('framenet_v17')
    return nltk.corpus.framenet.frames()


def extract_relations(fr):
    """List the frames directly related to `fr`, tagged with the relation direction.

    :param fr: a FrameNet frame object exposing `.name` and `.frameRelations`,
        where each relation is indexable by 'subFrameName'/'superFrameName'.
    :return: list of (related_frame_name, direction) tuples, direction being
        'subFrame' or 'superFrame'; the frame itself is excluded and each
        related frame appears at most once (first relation wins).
    """
    ret = list()
    added = {fr.name}
    for rel in fr.frameRelations:
        for key in ['subFrameName', 'superFrameName']:
            rel_fr_name = rel[key]
            if rel_fr_name in added:
                continue
            # Bug fix: record the emitted name so a frame linked through
            # several relations is not duplicated in the output.
            added.add(rel_fr_name)
            ret.append((rel_fr_name, key[:-4]))  # strip trailing 'Name'
    return ret


def run():
    """CLI entry point: rank FrameNet frames by model-internal similarity and dump TSVs.

    Loads a trained span predictor, computes frame-frame cosine similarity from
    two weight matrices (the label embedding table and the last typing-MLP layer),
    and writes per-source outputs (raw JSON, similarity/ontology TSVs, and a
    KAIROS review sheet) under <dst>/mlp and <dst>/emb.
    """
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    # cuda_device=-1: run the predictor on CPU.
    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    # Mapping of KAIROS event types to gold FrameNet annotations (schema assumed:
    # {event: {'framenet': [{'label': ...}], 'description': ...}} — see usage below).
    kairos_gold_mapping = json.load(open(args.kairos))

    # Two similarity sources: the span-label embedding table and the final MLP weights.
    label_emb = predictor._model._span_typing.label_emb.weight.clone().detach()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

    emb_sim = shift_grid_cos_sim(label_emb)
    # name -> (URL, definition) for every FrameNet frame; used to annotate output rows.
    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in all_frames()}

    last_mlp = predictor._model._span_typing.MLPs[-1].weight.detach().clone()
    mlp_sim = shift_grid_cos_sim(last_mlp)

    def rank_frame(sim):
        # For each FrameNet-valid label, list every other valid label sorted by
        # descending similarity. Labels in the model vocab that are not FrameNet
        # frames are skipped on both sides.
        rank = sim.argsort(1, True)
        scores = sim.gather(1, rank)
        mapping = {
            fr.name: {
                'similarity': list(),
                'ontology': extract_relations(fr),
                'URL': fr.URL,
                'definition': fr.definition
            } for fr in all_frames()
        }
        for left_idx, (right_indices, match_scores) in enumerate(zip(rank, scores)):
            left_label = idx2label[left_idx]
            if left_label not in mapping:
                continue
            for right_idx, s in zip(right_indices, match_scores):
                right_label = idx2label[int(right_idx)]
                # Drop self-matches and non-frame labels.
                if right_label not in mapping or right_idx == left_idx:
                    continue
                mapping[left_label]['similarity'].append((right_label, float(s)))
        return mapping

    emb_map = rank_frame(emb_sim)
    mlp_map = rank_frame(mlp_sim)

    def dump(mapping, folder_path):
        # Write raw.json, similarity.tsv, ontology.tsv, and kairos_sheet.tsv
        # for one similarity source into folder_path.
        os.makedirs(folder_path, exist_ok=True)
        json.dump(mapping, open(os.path.join(folder_path, 'raw.json'), 'w'))
        sim_lines, onto_lines = list(), list()

        for fr, values in mapping.items():
            sim_line = [
                fr,
                values['definition'],
                values['URL'],
            ]
            onto_line = deepcopy(sim_line)
            # NOTE(review): one ontology row is emitted per relation, each row
            # cumulatively extending the previous one — presumably intentional.
            for rel_fr_name, rel_type in values['ontology']:
                onto_line.append(f'{rel_fr_name} ({rel_type})')
                onto_lines.append('\t'.join(onto_line))
            if len(values['similarity']) > 0:
                # Keep only the top-k most similar frames per row.
                for sim_fr_name, score in values['similarity'][:args.topk]:
                    sim_line.append(f'{sim_fr_name} ({score:.3f})')
                sim_lines.append('\t'.join(sim_line))

        with open(os.path.join(folder_path, 'similarity.tsv'), 'w') as fp:
            fp.write('\n'.join(sim_lines))
        with open(os.path.join(folder_path, 'ontology.tsv'), 'w') as fp:
            fp.write('\n'.join(onto_lines))

        # KAIROS review sheet: each gold frame row (score 1.00) followed by its
        # top-k similar frames from this mapping.
        kairos_dump = list()
        for kairos_event, kairos_content in kairos_gold_mapping.items():
            for gold_fr in kairos_content['framenet']:
                gold_fr = gold_fr['label']
                if gold_fr not in fr2definition:
                    continue
                kairos_dump.append([
                    'GOLD',
                    gold_fr,
                    kairos_event,
                    fr2definition[gold_fr][0],
                    fr2definition[gold_fr][1],
                    str(kairos_content['description']),
                    '1.00'
                ])
                for ass_fr, sim_score in mapping[gold_fr]['similarity'][:args.topk]:
                    kairos_dump.append([
                        '',
                        ass_fr,
                        kairos_event,
                        fr2definition[ass_fr][0],
                        fr2definition[ass_fr][1],
                        str(kairos_content['description']),
                        f'{sim_score:.2f}'
                    ])
        kairos_dump = list(map(lambda line: '\t'.join(line), kairos_dump))
        open(os.path.join(folder_path, 'kairos_sheet.tsv'), 'w').write('\n'.join(kairos_dump))

    dump(mlp_map, os.path.join(args.dst, 'mlp'))
    dump(emb_map, os.path.join(args.dst, 'emb'))


# Script entry point: parse CLI args and export frame-similarity sheets.
if __name__ == '__main__':
    run()