"""Streamlit demo: unbalanced word alignment via (unbalanced) optimal transport.

Embeds two sentences with a pretrained encoder, pools subword vectors into
word vectors, and visualizes the OT-based word alignment matrix.
"""
import random

import numpy as np
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

from aligner import Aligner
from utils import plot_align_matrix_heatmap

device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


@st.cache_resource
def init_model(model_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_hidden_states=True).to(device).eval()
    return tokenizer, model


@st.cache_resource(max_entries=100)
def init_aligner(ot_type: str, sinkhorn: bool, distortion: float, threshold: float, tau: float):
    return Aligner(
        ot_type=ot_type,
        sinkhorn=sinkhorn,
        chimera=False,
        dist_type="cos",
        weight_type="uniform",
        distortion=distortion,
        thresh=threshold,
        tau=tau,
        div_type="--",
    )


def encode_sentence(sent, pair, tokenizer, model, layer: int):
    if pair is None:
        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True,
                           return_offsets_mapping=True, return_tensors="pt")
    else:
        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
                           is_split_into_words=True, return_offsets_mapping=True,
                           return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'].to(device),
                        attention_mask=inputs['attention_mask'].to(device),
                        token_type_ids=inputs['token_type_ids'].to(device))
    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]


def centering(hidden_outputs):
    """Center token embeddings.

    hidden_outputs: [tokens, hidden_size]
    """
    # Subtract the mean vector over all token embeddings.
    mean_vec = torch.mean(hidden_outputs, dim=0)
    return hidden_outputs - mean_vec


def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
    word_idx = -1
    subword_to_word_conv = np.full(hidden_tensors.shape[0], -1)
    # Possible bug in the Hugging Face tokenizer: sometimes a stray Metaspace
    # token is inserted, which must be skipped when grouping subwords.
    metaspace = getattr(tokenizer.decoder, "replacement", None)
    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids))
                           if x == metaspace]

    for subw_idx, offset in enumerate(offset_mapping):
        if subw_idx in tokenizer_bug_idxes:
            continue
        elif offset[0] == offset[1]:  # Special token
            continue
        elif offset[0] == 0:  # First subword of a new word
            word_idx += 1
            subword_to_word_conv[subw_idx] = word_idx
        else:  # Continuation subword of the current word
            subword_to_word_conv[subw_idx] = word_idx

    # Pool each word's subword vectors by averaging.
    word_embeddings = torch.vstack(
        [torch.mean(hidden_tensors[subword_to_word_conv == idx], dim=0)
         for idx in range(word_idx + 1)])

    if pair:
        # The second sentence starts at the first word-mapped subword after the first [SEP].
        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
        s2_start_idx = subword_to_word_conv[
            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]
        s1_word_embeddings = word_embeddings[0:s2_start_idx, :]
        s2_word_embeddings = word_embeddings[s2_start_idx:, :]
        return s1_word_embeddings, s2_word_embeddings
    else:
        return word_embeddings


def main():
    st.set_page_config(layout="wide")

    # Sidebar
    st.sidebar.markdown("## Settings & Parameters")
    model_name = st.sidebar.selectbox('model', ['microsoft/deberta-v3-base', 'bert-base-uncased'])
    layer = st.sidebar.slider('layer number for embeddings', 0, 11, value=9)
    is_centering = st.sidebar.checkbox('centering embeddings', value=True)
    ot_type = st.sidebar.selectbox('ot_type', ['OT', 'POT', 'UOT'])
    ot_type = ot_type.lower()
    sinkhorn = st.sidebar.checkbox('sinkhorn', value=True)
    distortion = st.sidebar.slider(r'distortion: $\kappa$', 0.0, 1.0, value=0.20)
    tau = st.sidebar.slider(r'tau: $\tau$', 0.0, 1.0, value=0.98)
    threshold = st.sidebar.slider(r'threshold: $\lambda$', 0.0, 1.0)

    # Content
    st.markdown('## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment')
    col1, col2 = st.columns(2)
    with col1:
        sent1 = st.text_area(
            'sentence 1',
            'By one estimate , fewer than 20,000 lions exist in the wild , a drop of about 40 percent in the past two decades .'
        )
    with col2:
        sent2 = st.text_area(
            'sentence 2',
            'Today there are only around 20,000 wild lions left in the world .'
        )

    tokenizer, model = init_model(model_name)
    aligner = init_aligner(ot_type, sinkhorn, distortion, threshold, tau)

    with st.container():
        st.write("word alignment matrix")
        if sent1 != '' and sent2 != '':
            sent1 = sent1.lower().split()
            sent2 = sent2.lower().split()
            hidden_output, input_id, offset_map = encode_sentence(sent1, sent2, tokenizer, model, layer=layer)
            if is_centering:
                hidden_output = centering(hidden_output)
            s1_vec, s2_vec = convert_to_word_embeddings(offset_map, input_id, hidden_output, tokenizer, pair=True)
            aligner.compute_alignment_matrixes([s1_vec], [s2_vec])
            align_matrix = aligner.align_matrixes[0]
            fig = plot_align_matrix_heatmap(align_matrix.T, sent1, sent2, threshold)
            st.pyplot(fig, dpi=300)

    st.divider()
    st.markdown("Note that the centering in this demo is computed from the two input sentences alone, so the variance may be large.")
    st.subheader('Refs')
    st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL 2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")


if __name__ == '__main__':
    main()
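# ---------------------------------------------------------------------------
# Usage note (a sketch, not part of the app logic): this script is a Streamlit
# app, so it is launched with the Streamlit CLI rather than plain `python`.
# The filename `demo.py` below is an assumption; substitute whatever this file
# is saved as. The `aligner` and `utils` modules are expected to come from the
# OTAlign repository (https://github.com/yukiar/OTAlign/tree/main), and extra
# dependencies beyond those imported here (e.g. sentencepiece for the
# DeBERTa-v3 tokenizer) may be required.
#
#   pip install streamlit torch transformers
#   streamlit run demo.py
# ---------------------------------------------------------------------------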