File size: 4,059 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from argparse import ArgumentParser
from typing import *

from flask import Flask
from flask import request

from sftp import SpanPredictor, Span

parser = ArgumentParser()
parser.add_argument('model', metavar='MODEL_PATH', type=str)
parser.add_argument('-p', metavar='PORT', type=int, default=7749)
parser.add_argument('-d', metavar='DEVICE', type=int, default=-1)
args = parser.parse_args()

template = open('tools/demo/flask_template.html').read()
predictor = SpanPredictor.from_path(args.model, cuda_device=args.d)
app = Flask(__name__)
default_sentence = '因为 आरजू です vegan , هي купил soja .'


def visualized_prediction(inputs: List[str], prediction: Span, prefix=''):
    spans = list()
    span2event = [[] for _ in inputs]
    for event_idx, event in enumerate(prediction):
        for arg_idx, arg in enumerate(event):
            for token_idx in range(arg.start_idx, arg.end_idx+1):
                span2event[token_idx].append((event_idx, arg_idx))

    for token_idx, token in enumerate(inputs):
        class_labels = ' '.join(
            ['token'] + [f'{prefix}-arg-{event_idx}-{arg_idx}' for event_idx, arg_idx in span2event[token_idx]]
        )
        spans.append(f'<span id="{prefix}-token-{token_idx}" class="{class_labels}" style="background-color">{token} </span>\n')

    for event_idx, event in enumerate(prediction):
        spans[event.start_idx] = (
            f'<span class="highlight bottom blue" '
            f' onmouseenter="highlight_args({event_idx}, \'{prefix}\')" onmouseleave="cancel_highlight(\'{prefix}\')">'
            '<span class="highlight__content" align="center">'
            f'<span class="event" id="{prefix}-event-{event_idx}">'
            + spans[event.start_idx]
        )
        spans[event.end_idx] += f'</span></span><span class="highlight__label"><center>{event.label}</center></span>'
        arg_tips = []
        for arg_idx, arg in enumerate(event):
            arg_tips.append(f'<span class="{prefix}-arg-{event_idx}-{arg_idx}">{arg.label}</span>')
        if len(arg_tips) > 0:
            arg_tips = '<br>'.join(arg_tips)
            spans[event.end_idx] += f'<span class="highlight__tooltip">{arg_tips}</span>\n'
        spans[event.end_idx] += '\n</span>'
    return(
            '<div class="passage model__content__summary highlight-container highlight-container--bottom-labels">\n' +
            '\n'.join(spans) + '\n</div>'
    )


def structured_prediction(inputs, prediction):
    ret = list()
    for event in prediction:
        event_text, event_label = ' '.join(inputs[event.start_idx: event.end_idx+1]), event.label
        ret.append(f'<li class="list-group-item list-group-item-info">'
                   f'<strong>{event_label}</strong>: {event_text}</li>')
        for arg in event:
            arg_text = ' '.join(inputs[arg.start_idx: arg.end_idx+1])
            ret.append(
                f'<li class="list-group-item">&nbsp;&nbsp;&nbsp;&nbsp;<strong>{arg.label}</strong>: {arg_text}</li>'
            )
    content = '\n'.join(ret)
    return '\n<ul class="list-group">\n' + content + '\n</ul>'


@app.route('/')
def sftp():
    ret = template
    tokens = request.args.get('sentence')
    if tokens is not None:
        ret = ret.replace('DEFAULT_SENTENCE', tokens)
        sentences = tokens.split('\n')
        model_outputs = predictor.predict_batch_sentences(sentences, max_tokens=512)
        vis_pred, str_pred = list(), list()
        for sent_idx, output in enumerate(model_outputs):
            vis_pred.append(visualized_prediction(output.sentence, output.span, f'sent{sent_idx}'))
            str_pred.append(structured_prediction(output.sentence, output.span))
        ret = ret.replace('VISUALIZED_PREDICTION', '<hr>'.join(vis_pred))
        ret = ret.replace('STRUCTURED_PREDICTION', '<hr>'.join(str_pred))
    else:
        ret = ret.replace('DEFAULT_SENTENCE', default_sentence)
        ret = ret.replace('VISUALIZED_PREDICTION', '')
        ret = ret.replace('STRUCTURED_PREDICTION', '')
    return ret


app.run(port=args.p)