"""Streamlit demo: unbalanced optimal transport for word alignment (OTAlign)."""
import random

import numpy as np
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

from aligner import Aligner
from utils import plot_align_matrix_heatmap

# Run on GPU when available; fix all RNG seeds for reproducibility across reruns.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


@st.cache_resource
def init_model(model_name: str):
    """Load and cache the tokenizer and encoder, with hidden states exposed."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_hidden_states=True).to(device).eval()
    return tokenizer, model


@st.cache_resource(max_entries=100)
def init_aligner(ot_type: str, sinkhorn: bool, distortion: float, threshold: float, tau: float):
    """Build and cache an Aligner for the chosen OT variant and hyperparameters."""
    return Aligner(
        ot_type=ot_type,
        sinkhorn=sinkhorn,
        chimera=False,
        dist_type="cos",
        weight_type="uniform",
        distortion=distortion,
        thresh=threshold,
        tau=tau,
        div_type="--",
    )
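
# The knobs above follow the OTAlign formulation (Arase et al., 2023):
# `ot_type` picks balanced OT ("ot"), partial OT ("pot"), or unbalanced OT
# ("uot"); `sinkhorn` switches to the entropy-regularized solver; `tau`
# controls how strictly the transport must respect the marginals for
# POT/UOT; and `thresh` is the cutoff used when reading discrete alignments
# off the soft matrix. A minimal usage sketch (values are illustrative,
# not recommendations):
#
#     aligner = init_aligner('uot', sinkhorn=True, distortion=0.2,
#                            threshold=0.1, tau=0.98)
#     aligner.compute_alignment_matrixes([s1_vec], [s2_vec])
#     align_matrix = aligner.align_matrixes[0]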


def encode_sentence(sent, pair, tokenizer, model, layer: int):
    """Encode a pre-tokenized sentence (or sentence pair) and return the
    subword embeddings from the given layer, plus input ids and offsets."""
    if pair is None:
        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True,
                           return_offsets_mapping=True, return_tensors="pt")
    else:
        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
                           is_split_into_words=True, return_offsets_mapping=True,
                           return_tensors="pt")

    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'].to(device),
                        attention_mask=inputs['attention_mask'].to(device),
                        token_type_ids=inputs['token_type_ids'].to(device))

    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
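
# For a batch of one sentence (or pair), `outputs.hidden_states[layer][0]` is
# a [num_subwords, hidden_size] tensor; the returned ids and offsets line up
# with it row by row. Illustrative shapes, assuming a BERT-style tokenizer:
#
#     emb, ids, offsets = encode_sentence(['a', 'cat'], None, tokenizer, model, layer=9)
#     # subwords: [CLS] a cat [SEP]  ->  emb.shape == (4, hidden_size)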


def centering(hidden_outputs):
    """
    hidden_outputs : [tokens, hidden_size]
    Subtract the mean token vector from every token embedding.
    """
    mean_vec = torch.mean(hidden_outputs, dim=0)
    return hidden_outputs - mean_vec
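
# Centering computes x_i <- x_i - (1/n) * sum_j x_j, removing the component
# shared by all tokens, which otherwise inflates every cosine similarity.
# Sanity check: if all rows are identical, the result is all zeros, e.g.
# centering(torch.ones(3, 4)) returns a zero tensor.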


def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
    """Mean-pool subword embeddings into word embeddings; if `pair`, split the
    result back into the two sentences at the first separator token."""
    word_idx = -1
    subword_to_word_conv = np.full(hidden_tensors.shape[0], -1)

    # Some tokenizers emit a bare metaspace/prefix token; skip those positions.
    metaspace = getattr(tokenizer.decoder, "replacement", None)
    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids))
                           if x == metaspace]

    for subw_idx, offset in enumerate(offset_mapping):
        if subw_idx in tokenizer_bug_idxes:
            continue
        elif offset[0] == offset[1]:  # special tokens such as [CLS]/[SEP]
            continue
        elif offset[0] == 0:  # first subword of a new word
            word_idx += 1
            subword_to_word_conv[subw_idx] = word_idx
        else:  # continuation subword of the current word
            subword_to_word_conv[subw_idx] = word_idx

    # Mean-pool the subword vectors belonging to each word.
    word_embeddings = torch.vstack(
        [torch.mean(hidden_tensors[subword_to_word_conv == i], dim=0) for i in range(word_idx + 1)])

    if pair:
        # The second sentence starts at the first word after the first [SEP].
        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
        s2_start_idx = subword_to_word_conv[
            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]

        s1_word_embeddings = word_embeddings[0:s2_start_idx, :]
        s2_word_embeddings = word_embeddings[s2_start_idx:, :]

        return s1_word_embeddings, s2_word_embeddings
    else:
        return word_embeddings
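
# Illustrative trace, assuming a WordPiece-style tokenizer: if "lions" is
# split into "lion" + "##s", both rows receive the same word index and are
# mean-pooled into one vector. The offset test works because, with
# `is_split_into_words=True`, offsets restart at 0 at the first subword of
# each word, and special tokens get the empty offset (0, 0).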


def main():
    st.set_page_config(layout="wide")

    st.sidebar.markdown("## Settings & Parameters")
    model_name = st.sidebar.selectbox('model', ['microsoft/deberta-v3-base', 'bert-base-uncased'])
    layer = st.sidebar.slider(
        'layer number for embeddings', 0, 11, value=9
    )
    is_centering = st.sidebar.checkbox('centering embeddings', value=True)
    ot_type = st.sidebar.selectbox('ot_type', ['OT', 'POT', 'UOT'])
    ot_type = ot_type.lower()
    sinkhorn = st.sidebar.checkbox('sinkhorn', value=True)
    distortion = st.sidebar.slider(
        r'distortion: $\kappa$', 0.0, 1.0, value=0.20
    )
    tau = st.sidebar.slider(
        r'tau: $\tau$', 0.0, 1.0, value=0.98
    )
    threshold = st.sidebar.slider(
        r'threshold: $\lambda$', 0.0, 1.0
    )
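
    # The slider symbols follow the paper's notation: kappa weights the
    # distortion term, tau relaxes the marginal constraints, and lambda is
    # the threshold for keeping an alignment edge when plotting.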

    st.markdown('## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment')

    col1, col2 = st.columns(2)

    with col1:
        sent1 = st.text_area(
            'sentence 1',
            'By one estimate , fewer than 20,000 lions exist in the wild , a drop of about 40 percent in the past two decades .'
        )
    with col2:
        sent2 = st.text_area(
            'sentence 2',
            'Today there are only around 20,000 wild lions left in the world .'
        )

    tokenizer, model = init_model(model_name)
    aligner = init_aligner(ot_type, sinkhorn, distortion, threshold, tau)

    with st.container():
        st.write("word alignment matrix")

        if sent1 != '' and sent2 != '':
            # The aligner operates on whitespace-tokenized, lowercased words.
            sent1 = sent1.lower().split()
            sent2 = sent2.lower().split()
            hidden_output, input_id, offset_map = encode_sentence(sent1, sent2, tokenizer, model, layer=layer)
            if is_centering:
                hidden_output = centering(hidden_output)
            s1_vec, s2_vec = convert_to_word_embeddings(offset_map, input_id, hidden_output, tokenizer, pair=True)
            aligner.compute_alignment_matrixes([s1_vec], [s2_vec])
            align_matrix = aligner.align_matrixes[0]

            fig = plot_align_matrix_heatmap(align_matrix.T, sent1, sent2, threshold)
            st.pyplot(fig, dpi=300)

    st.divider()
    st.markdown("Note that centering in this demo estimates the mean from the two input sentences alone, so the estimate can be noisy.")
    st.subheader('References')
    st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL 2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")


if __name__ == '__main__':
    main()
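
# To try the demo locally (assuming this file is saved as app.py, with the
# `aligner` and `utils` modules from the OTAlign repository alongside it):
#
#     streamlit run app.py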