Spaces:
Runtime error
Runtime error
File size: 7,172 Bytes
2b49fe2 efeee8a 2b49fe2 efeee8a 2b49fe2 d82123b dd4548e 2b49fe2 e87e116 efeee8a e87e116 efeee8a c6dd7aa efeee8a d3bd75e c5489ad c6dd7aa 50ce4f4 2b66ae3 50ce4f4 b6390e8 2b66ae3 50ce4f4 2b66ae3 b6390e8 50ce4f4 b6390e8 50ce4f4 b6390e8 50ce4f4 b6390e8 50ce4f4 8f32fbf c6dd7aa c5489ad c6dd7aa c5489ad 50ce4f4 c5489ad c6dd7aa 50ce4f4 c5489ad 2b66ae3 57c67f2 50ce4f4 10ced5b 50ce4f4 10ced5b 50ce4f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import html
import time

import numpy as np
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM

#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
def wide_setup():
    """Widen the Streamlit page and inject two raw CSS snippets.

    The first snippet sets the maximum width and the padding of the main
    block container. The second sets ``display:none`` on ``tbody th`` and
    ``.blank`` elements (presumably to hide table row-index cells).
    Both are emitted through ``st.markdown`` with raw HTML enabled.
    """
    width_px = 1500
    pad_top, pad_right, pad_bottom, pad_left = 0, 2, 0, 2

    # Doubled braces are literal CSS braces inside the f-string.
    margin_css = f"""
<style>
.appview-container .main .block-container{{
max-width: {width_px}px;
padding-top: {pad_top}rem;
padding-right: {pad_right}rem;
padding-left: {pad_left}rem;
padding-bottom: {pad_bottom}rem;
}}
</style>
"""
    row_index_css = """
<style>
tbody th {display:none}
.blank {display:none}
</style>
"""
    for css in (margin_css, row_index_css):
        st.markdown(css, unsafe_allow_html=True)
def load_css(file_name):
    """Inject the contents of a CSS file into the page.

    The file is read as UTF-8, wrapped in a ``<style>`` tag, and emitted
    via ``st.markdown`` with raw HTML enabled.

    Args:
        file_name: Path to the stylesheet.
    """
    # Explicit encoding: without it open() uses the locale's preferred
    # encoding, so the same stylesheet could decode differently per host.
    with open(file_name, encoding='utf-8') as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model():
    """Load the 'albert-xxlarge-v2' tokenizer and masked-LM model.

    Decorated with ``st.cache`` so the loaded objects are reused across
    Streamlit reruns rather than reloaded each time.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'albert-xxlarge-v2'
    albert_tokenizer = AlbertTokenizer.from_pretrained(checkpoint)
    albert_model = AlbertForMaskedLM.from_pretrained(checkpoint)
    return albert_tokenizer, albert_model
def clear_data(state=None):
    """Delete every key from a session-state mapping.

    Args:
        state: Mapping to clear. Defaults to ``st.session_state``; accepting
            another mapping lets the function run without a live app.
    """
    if state is None:
        state = st.session_state
    # Snapshot the keys first: deleting from a mapping while iterating over
    # it can raise "RuntimeError: dictionary changed size during iteration".
    for key in list(state.keys()):
        del state[key]
def annotate_mask(sent_id,sent):
    """Let the user pick which tokens of one sentence to mask.

    The sentence is tokenized with the module-level ``tokenizer``, and the
    first and last ids are dropped. Each remaining token is drawn as a
    button in its own column. Clicking a button toggles that token's index
    in ``st.session_state[f'mask_locs_{sent_id}']``. The sentence is then
    re-rendered with the selected tokens highlighted.

    Args:
        sent_id: Sentence label, used in the heading and to namespace the
            session-state entry and the widget keys.
        sent: Raw sentence text.

    Returns:
        The list of selected token indices held in session state.
    """
    st.write(f'Sentence {sent_id}')
    state_key = f'mask_locs_{sent_id}'
    if state_key not in st.session_state:
        st.session_state[state_key] = []

    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    if not decoded_sent:
        # st.columns rejects an empty list of widths; nothing to select.
        return st.session_state[state_key]

    # Column widths proportional to token length (+2, presumably as room
    # around the button label).
    cols = st.columns([len(word)+2 for word in decoded_sent])
    for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
        with col:
            if st.button(word,key=f'word_mask_{sent_id}_{word_id}'):
                selected = st.session_state[state_key]
                # Toggle: a second click deselects the token.
                if word_id in selected:
                    selected.remove(word_id)
                else:
                    selected.append(word_id)

    st.markdown(show_annotated_sentence(decoded_sent,
        mask_locs=st.session_state[state_key]), unsafe_allow_html = True)
    return st.session_state[state_key]
def annotate_options(sent_id,sent):
    """Let the user pick the option tokens of one sentence.

    The sentence is tokenized with the module-level ``tokenizer``, and the
    first and last ids are dropped. Each remaining token is drawn as a
    button in its own column. Clicking a button toggles that token's index
    in ``st.session_state[f'option_locs_{sent_id}']``. The sentence is then
    re-rendered showing both the option tokens and any mask tokens stored
    under ``mask_locs_{sent_id}``.

    Args:
        sent_id: Sentence label, used in the heading and to namespace the
            session-state entries and the widget keys.
        sent: Raw sentence text.

    Returns:
        The list of selected option indices held in session state. The list
        is returned (rather than None) so a caller that stores the result
        back under ``option_locs_{sent_id}`` no longer clobbers it.
    """
    st.write(f'Sentence {sent_id}')
    state_key = f'option_locs_{sent_id}'
    if state_key not in st.session_state:
        st.session_state[state_key] = []

    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    if not decoded_sent:
        # st.columns rejects an empty list of widths; nothing to select.
        return st.session_state[state_key]

    # Column widths proportional to token length (+2, presumably as room
    # around the button label).
    cols = st.columns([len(word)+2 for word in decoded_sent])
    for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
        with col:
            if st.button(word,key=f'word_option_{sent_id}_{word_id}'):
                selected = st.session_state[state_key]
                # Toggle: a second click deselects the token.
                if word_id in selected:
                    selected.remove(word_id)
                else:
                    selected.append(word_id)

    # Mask selections come from the earlier step; fall back to none if that
    # step has not populated session state.
    st.markdown(show_annotated_sentence(decoded_sent,
        option_locs=st.session_state[state_key],
        mask_locs=st.session_state.get(f'mask_locs_{sent_id}', [])), unsafe_allow_html = True)
    return st.session_state[state_key]
def show_annotated_sentence(sent,option_locs=None,mask_locs=None):
    """Return an HTML paragraph rendering ``sent`` with highlighted tokens.

    Tokens whose index is in ``mask_locs`` are coloured red. Otherwise,
    tokens whose index is in ``option_locs`` are coloured blue. All other
    tokens are left plain. Tokens are HTML-escaped, joined with single
    spaces, and wrapped in a bold span inside a styled ``<p>``.

    Args:
        sent: Sequence of token strings.
        option_locs: Token indices to colour blue (default: none).
        mask_locs: Token indices to colour red; these take precedence over
            ``option_locs`` (default: none).

    Returns:
        An HTML string meant for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    # None defaults avoid shared mutable default lists; sets give O(1) lookups.
    option_set = set(option_locs or ())
    mask_set = set(mask_locs or ())

    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'

    style_list = []
    for i, word in enumerate(sent):
        # Tokens derive from user-typed text and are rendered as raw HTML;
        # escape them so e.g. "<unk>" displays instead of becoming a tag.
        safe_word = html.escape(word)
        if i in mask_set:
            style_list.append(f'<span style="color:Red">{safe_word}</span>')
        elif i in option_set:
            style_list.append(f'<span style="color:Blue">{safe_word}</span>')
        else:
            style_list.append(safe_word)
    return prefix + ' '.join(style_list) + suffix
if __name__=='__main__':
    # Streamlit re-executes this whole script on every interaction. The
    # 'page_status' session key selects which step is drawn:
    # type_in -> annotate_mask -> annotate_options -> analysis.
    wide_setup()
    load_css('style.css')
    # annotate_mask/annotate_options read `tokenizer` as a module-level
    # global, so it must be bound here at module scope.
    tokenizer,model = load_model()
    # Id of the '[MASK]' token: drop the first and last ids the tokenizer
    # adds around the text (the same [1:-1] slice annotate_* use).
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
    main_area = st.empty()

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'

    # Step 1: enter the two sentences.
    if st.session_state['page_status']=='type_in':
        with main_area.container():
            st.write('1. Type in the sentences and click "Tokenize"')
            sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
            sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
            if st.button('Tokenize'):
                st.session_state['page_status'] = 'annotate_mask'
                st.session_state['sent_1'] = sent_1
                st.session_state['sent_2'] = sent_2
                main_area.empty()

    # Step 2: choose which tokens to mask.
    if st.session_state['page_status']=='annotate_mask':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']
            st.write('2. Select sites to mask out and click "Confirm"')
            annotate_mask(1,sent_1)
            annotate_mask(2,sent_2)
            st.write(st.session_state['mask_locs_1'])
            st.write(st.session_state['mask_locs_2'])
            # Explicit key: step 3 also draws a "Confirm" button, in this same
            # run once this one is clicked. Two keyless buttons with the same
            # label raise DuplicateWidgetID.
            if st.button('Confirm',key='confirm_mask'):
                st.session_state['page_status'] = 'annotate_options'
                main_area.empty()

    # Step 3: choose the option tokens.
    if st.session_state['page_status'] == 'annotate_options':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']
            st.write('3. Select options and click "Confirm"')
            # annotate_options maintains 'option_locs_{id}' in session state
            # itself; do not overwrite that entry with its return value.
            annotate_options(1,sent_1)
            annotate_options(2,sent_2)
            if st.button('Confirm',key='confirm_options'):
                st.session_state['page_status'] = 'analysis'
                main_area.empty()

    # Step 4: run the model and show sampled tokens.
    if st.session_state['page_status']=='analysis':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']
            input_ids_1 = tokenizer(sent_1).input_ids
            input_ids_2 = tokenizer(sent_2).input_ids
            if len(input_ids_1) != len(input_ids_2):
                # torch.tensor cannot build a ragged batch, so both sentences
                # must tokenize to the same number of ids.
                st.error(f'Both sentences must have the same number of tokens '
                         f'(got {len(input_ids_1)} and {len(input_ids_2)}).')
            else:
                input_ids = torch.tensor([input_ids_1,input_ids_2])
                # NOTE(review): mask_id and the mask/option selections from
                # steps 2-3 are not applied here yet, and the intervention
                # spec is hard-coded; its meaning is defined in
                # skeleton_modeling_albert — confirm there.
                outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
                logprobs = F.log_softmax(outputs['logits'], dim = -1)
                # Sample one token id per position of logprobs[0] (presumably
                # batch row 0, i.e. sentence 1) from the predicted distribution.
                preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
                st.write([tokenizer.decode([token]) for token in preds])
|