import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
def wide_setup():
    # Widen the main Streamlit container and hide table row indices via CSS.
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
    <style>
        tbody th {display:none}
        .blank {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)
def load_css(file_name):
    # Inject an external CSS file into the page.
    with open(file_name) as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
def load_model():
    # Load the ALBERT-xxlarge-v2 tokenizer and masked-LM model; note this
    # runs on every Streamlit rerun unless cached (see the sketch below).
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
    return tokenizer, model
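# A possible caching wrapper (a sketch, assuming Streamlit >= 1.18 where
# st.cache_resource exists; older versions used st.cache instead):
# @st.cache_resource
# def load_model_cached():
#     return load_model()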
def clear_data():
    # Iterate over a copy of the keys: deleting entries while iterating
    # over st.session_state directly raises a RuntimeError.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
def annotate_mask(sent_id, sent):
    # Render one button per token (excluding [CLS]/[SEP]); clicking a token
    # toggles its position in the list of sites to mask.
    st.write(f'Sentence {sent_id}')
    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    char_nums = [len(word)+2 for word in decoded_sent]
    cols = st.columns(char_nums)
    if f'mask_locs_{sent_id}' not in st.session_state:
        st.session_state[f'mask_locs_{sent_id}'] = []
    for word_id, (col, word) in enumerate(zip(cols, decoded_sent)):
        with col:
            if st.button(word, key=f'word_mask_{sent_id}_{word_id}'):
                if word_id not in st.session_state[f'mask_locs_{sent_id}']:
                    st.session_state[f'mask_locs_{sent_id}'].append(word_id)
                else:
                    st.session_state[f'mask_locs_{sent_id}'].remove(word_id)
    st.markdown(show_annotated_sentence(decoded_sent,
                                        mask_locs=st.session_state[f'mask_locs_{sent_id}']),
                unsafe_allow_html=True)
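# Note: st.button is True only on the rerun triggered by the click, so the
# annotators above and below rely on st.session_state to persist the toggled
# selections across reruns.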
def annotate_options(sent_id, sent):
    # Same toggling UI as annotate_mask, but collects the option positions;
    # the rendered sentence shows both the options and the chosen masks.
    st.write(f'Sentence {sent_id}')
    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    char_nums = [len(word)+2 for word in decoded_sent]
    cols = st.columns(char_nums)
    if f'option_locs_{sent_id}' not in st.session_state:
        st.session_state[f'option_locs_{sent_id}'] = []
    for word_id, (col, word) in enumerate(zip(cols, decoded_sent)):
        with col:
            if st.button(word, key=f'word_option_{sent_id}_{word_id}'):
                if word_id not in st.session_state[f'option_locs_{sent_id}']:
                    st.session_state[f'option_locs_{sent_id}'].append(word_id)
                else:
                    st.session_state[f'option_locs_{sent_id}'].remove(word_id)
    st.markdown(show_annotated_sentence(decoded_sent,
                                        option_locs=st.session_state[f'option_locs_{sent_id}'],
                                        mask_locs=st.session_state[f'mask_locs_{sent_id}']),
                unsafe_allow_html=True)
def show_annotated_sentence(sent, option_locs=None, mask_locs=None):
    # Build an HTML rendering of the sentence: masked tokens in red,
    # option tokens in blue, everything else in black.
    # (None defaults instead of mutable [] defaults.)
    option_locs = option_locs or []
    mask_locs = mask_locs or []
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    style_list = []
    for i, word in enumerate(sent):
        if i in mask_locs:
            style_list.append(f'<span style="color:Red">{word}</span>')
        elif i in option_locs:
            style_list.append(f'<span style="color:Blue">{word}</span>')
        else:
            style_list.append(f'{word}')
    disp = ' '.join(style_list)
    suffix = '</span></p>'
    return prefix + disp + suffix
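# Example (a sketch; exact spacing depends on the tokenizer's detokenization):
#   show_annotated_sentence(['he', 'gets', 'angry'], mask_locs=[2])
#   -> '<p style="font-family:sans-serif; color:Black; font-size: 20px">
#       <span style="font-weight:bold">he gets
#       <span style="color:Red">angry</span></span></p>'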
if __name__ == '__main__':
    wide_setup()
    load_css('style.css')
    tokenizer, model = load_model()
    # ID of the [MASK] token ([1:-1] strips [CLS]/[SEP] from the encoding)
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
    main_area = st.empty()
    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'
    if st.session_state['page_status'] == 'type_in':
        with main_area.container():
            st.write('1. Type in the sentences and click "Tokenize"')
            sent_1 = st.text_input('Sentence 1', value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
            sent_2 = st.text_input('Sentence 2', value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
            if st.button('Tokenize'):
                st.session_state['page_status'] = 'annotate_mask'
                st.session_state['sent_1'] = sent_1
                st.session_state['sent_2'] = sent_2
                main_area.empty()
    if st.session_state['page_status'] == 'annotate_mask':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']
            st.write('2. Select sites to mask out and click "Confirm"')
            annotate_mask(1, sent_1)
            annotate_mask(2, sent_2)
            # Show the currently selected mask positions for both sentences
            st.write(st.session_state['mask_locs_1'])
            st.write(st.session_state['mask_locs_2'])
            if st.button('Confirm'):
                st.session_state['page_status'] = 'annotate_options'
                main_area.empty()
    if st.session_state['page_status'] == 'annotate_options':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']
            st.write('3. Select options and click "Confirm"')
            # annotate_options records its selections in st.session_state and
            # returns None, so its return value must not be assigned back
            # (doing so would overwrite the selection lists with None).
            annotate_options(1, sent_1)
            annotate_options(2, sent_2)
            if st.button('Confirm'):
                st.session_state['page_status'] = 'analysis'
                main_area.empty()
    if st.session_state['page_status'] == 'analysis':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']
            input_ids_1 = tokenizer(sent_1).input_ids
            input_ids_2 = tokenizer(sent_2).input_ids
            # Stacking assumes both sentences tokenize to the same length
            # (true for the default pair, which differ in a single word).
            input_ids = torch.tensor([input_ids_1, input_ids_2])
            # Apply the masks chosen in step 2 (assumed intent: mask_id and
            # the mask selections are otherwise unused); +1 offsets for [CLS].
            for sid in (1, 2):
                for loc in st.session_state[f'mask_locs_{sid}']:
                    input_ids[sid-1, loc+1] = mask_id
            # Run the skeleton model with a hard-coded intervention spec
            # (format defined in skeleton_modeling_albert).
            outputs = SkeletonAlbertForMaskedLM(model, input_ids,
                                                interventions={0: {'lay': [(8, 1, [0, 1])]}})
            logprobs = F.log_softmax(outputs['logits'], dim=-1)
            # Sample one token per position of sentence 1 and show the decoded strings
            preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1)
                     for probs in logprobs[0]]
            st.write([tokenizer.decode([token]) for token in preds])
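            # A hedged alternative readout (sketch, standard torch API only):
            # instead of sampling every position, show the top-5 candidates at
            # each masked position of sentence 1 (+1 offsets for [CLS]).
            # for loc in [l+1 for l in st.session_state['mask_locs_1']]:
            #     topk = torch.topk(logprobs[0, loc], k=5)
            #     st.write([tokenizer.decode([i]) for i in topk.indices])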