import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp

import torch
import torch.nn.functional as F

from transformers import AlbertTokenizer, AlbertForMaskedLM

#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
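
# Streamlit demo for probing ALBERT masked-LM predictions, with activation
# interventions supplied through SkeletonAlbertForMaskedLM. Assuming this file is
# saved as app.py, it can be launched with:
#   streamlit run app.py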

def wide_setup():
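    """Inject CSS that widens the main content area, trims the page padding,
    and hides the index column of rendered tables."""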
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
                <style>
                tbody th {display:none}
                .blank {display:none}
                </style>
                """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

def load_css(file_name):
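    """Read a local CSS file and inject it into the page."""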
    with open(file_name) as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)

@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model():
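    """Load the ALBERT-xxlarge-v2 tokenizer and masked-LM model (cached across reruns)."""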
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
    return tokenizer,model

def clear_data():
    # Iterate over a snapshot of the keys: deleting entries while iterating over
    # st.session_state directly would change its size mid-iteration.
    for key in list(st.session_state.keys()):
        del st.session_state[key]

def annotate_mask(sent_id,sent):
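    """Render `sent` as one button per token; clicking a token toggles its index in
    st.session_state[f'mask_locs_{sent_id}'], and the sentence is echoed below with
    the selected (to-be-masked) tokens highlighted in red."""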
    st.write(f'Sentence {sent_id}')
    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    char_nums = [len(word)+2 for word in decoded_sent]
    cols = st.columns(char_nums)
    if f'mask_locs_{sent_id}' not in st.session_state:
        st.session_state[f'mask_locs_{sent_id}'] = []
    for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
        with col:
            if st.button(word,key=f'word_mask_{sent_id}_{word_id}'):
                if word_id not in st.session_state[f'mask_locs_{sent_id}']:
                    st.session_state[f'mask_locs_{sent_id}'].append(word_id)
                else:
                    st.session_state[f'mask_locs_{sent_id}'].remove(word_id)
    st.markdown(show_annotated_sentence(decoded_sent,
                                        mask_locs=st.session_state[f'mask_locs_{sent_id}']), unsafe_allow_html = True)

def annotate_options(sent_id,sent):
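    """Render `sent` as one button per token; clicking a token toggles its index in
    st.session_state[f'option_locs_{sent_id}'], and the sentence is echoed below with
    option tokens in blue and previously selected mask tokens in red."""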
    st.write(f'Sentence {sent_id}')
    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    char_nums = [len(word)+2 for word in decoded_sent]
    cols = st.columns(char_nums)
    if f'option_locs_{sent_id}' not in st.session_state:
        st.session_state[f'option_locs_{sent_id}'] = []
    for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
        with col:
            if st.button(word,key=f'word_option_{sent_id}_{word_id}'):
                if word_id not in st.session_state[f'option_locs_{sent_id}']:
                    st.session_state[f'option_locs_{sent_id}'].append(word_id)
                else:
                    st.session_state[f'option_locs_{sent_id}'].remove(word_id)
    st.markdown(show_annotated_sentence(decoded_sent,
                                        option_locs=st.session_state[f'option_locs_{sent_id}'],
                                        mask_locs=st.session_state[f'mask_locs_{sent_id}']), unsafe_allow_html = True)

def show_annotated_sentence(sent,option_locs=[],mask_locs=[]):
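    """Return `sent` as an HTML string with mask tokens in red and option tokens in blue."""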
    disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    style_list = []
    for i, word in enumerate(sent):
        if i in mask_locs:
            style_list.append(f'<span style="color:Red">{word}</span>')
        elif i in option_locs:
            style_list.append(f'<span style="color:Blue">{word}</span>')
        else:
            style_list.append(f'{word}')
    disp = ' '.join(style_list)
    suffix = '</span></p>'
    return prefix + disp + suffix

if __name__=='__main__':
    wide_setup()
    load_css('style.css')
    tokenizer,model = load_model()
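    # Token id of [MASK]; slicing [1:-1] drops the [CLS]/[SEP] specials added by the tokenizer.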
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

    main_area = st.empty()

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'
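
    # The app proceeds through four stages, tracked in st.session_state['page_status']:
    # 'type_in' -> 'annotate_mask' -> 'annotate_options' -> 'analysis'.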

    if st.session_state['page_status']=='type_in':
        with main_area.container():
            st.write('1. Type in the sentences and click "Tokenize"')
            sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
            sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
            if st.button('Tokenize'):
                st.session_state['page_status'] = 'annotate_mask'
                st.session_state['sent_1'] = sent_1
                st.session_state['sent_2'] = sent_2
                main_area.empty()

    if st.session_state['page_status']=='annotate_mask':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']

            st.write('2. Select sites to mask out and click "Confirm"')
            annotate_mask(1,sent_1)
            annotate_mask(2,sent_2)
            st.write(st.session_state['mask_locs_1'])
            st.write(st.session_state['mask_locs_2'])
            if st.button('Confirm'):
                st.session_state['page_status'] = 'annotate_options'
                main_area.empty()

    if st.session_state['page_status'] == 'annotate_options':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']

            st.write('3. Select options and click "Confirm"')
            # annotate_options records selections in st.session_state itself;
            # assigning its return value (None) here would overwrite them.
            annotate_options(1,sent_1)
            annotate_options(2,sent_2)
            if st.button('Confirm'):
                st.session_state['page_status'] = 'analysis'
                main_area.empty()

    if st.session_state['page_status']=='analysis':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']

            input_ids_1 = tokenizer(sent_1).input_ids
            input_ids_2 = tokenizer(sent_2).input_ids
            # Note: stacking into one tensor assumes both sentences tokenize to the same length.
            input_ids = torch.tensor([input_ids_1,input_ids_2])

            # The interventions dict is consumed by skeleton_modeling_albert, which defines
            # the exact meaning of the (8, 1, [0, 1]) 'lay' entry.
            outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
            # Log-probabilities over the vocabulary at every position of both sentences.
            logprobs = F.log_softmax(outputs['logits'], dim = -1)
            # For sentence 1, sample one token per position and display the decoded tokens.
            preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
            st.write([tokenizer.decode([token]) for token in preds])