"""Streamlit playground for optimal-transport-based word alignment.

The page lets the user pick an encoder and the aligner parameters in the
sidebar, enter two sentences, and then shows the alignment matrix, the
similarity matrix and (optionally) a 2-D UMAP view of the word embeddings.
"""
import random

import nltk
import numpy as np
import streamlit as st
import torch
import umap
from nltk.tokenize import word_tokenize
from transformers import AutoModel, AutoTokenizer

from aligner import Aligner
from plotools import (
    plot_align_matrix_heatmap_plotly,
    plot_similarity_matrix_heatmap_plotly,
    show_assignments_plotly,
)
from utils import centering, convert_to_word_embeddings, encode_sentence

# Run the encoder on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every random number generator used by the app.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Tokenizer data for nltk's word_tokenize.
nltk.download("punkt")


@st.cache_resource
def init_model(model: str):
    """Load the tokenizer and encoder for a Hugging Face model name.

    The result is cached by Streamlit (``st.cache_resource``), so the
    weights are not reloaded on every script rerun.

    Args:
        model: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is configured to output
        hidden states, moved to ``device`` and put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)
    # Use a separate local name instead of rebinding the ``model`` argument.
    encoder = AutoModel.from_pretrained(model, output_hidden_states=True)
    encoder = encoder.to(device).eval()
    return tokenizer, encoder


@st.cache_resource(max_entries=100)
def init_aligner(
    ot_type: str, sinkhorn: bool, distortion: float, threshhold: float, tau: float
):
    """Build an ``Aligner`` for the given parameter combination.

    Cached by Streamlit with at most 100 entries, keyed on the arguments.

    Args:
        ot_type: Lower-cased optimal transport variant chosen in the sidebar.
        sinkhorn: Whether to use the sinkhorn algorithm.
        distortion: Distortion value forwarded to the aligner.
        threshhold: Forwarded to the aligner as ``thresh``.
        tau: Tau value forwarded to the aligner.

    Returns:
        An ``Aligner`` using cosine distance and uniform weights.
    """
    return Aligner(
        ot_type=ot_type,
        sinkhorn=sinkhorn,
        dist_type="cos",
        weight_type="uniform",
        distortion=distortion,
        thresh=threshhold,
        tau=tau,
        div_type="--",
    )


def main() -> None:
    """Render the playground page: sidebar settings, inputs and plots."""
    st.set_page_config(layout="wide")

    # Sidebar
    st.sidebar.markdown("## Settings & Parameters")
    model = st.sidebar.selectbox(
        "model", ["microsoft/deberta-v3-base", "bert-base-uncased"]
    )
    layer = st.sidebar.slider(
        "layer number for embeddings",
        0,
        11,
        value=9,
    )
    is_centering = st.sidebar.checkbox("centering embeddings", value=True)
    ot_type = st.sidebar.selectbox(
        "ot_type", ["POT", "UOT", "OT"], help="optimal transport algorithm to be used"
    )
    ot_type = ot_type.lower()
    sinkhorn = st.sidebar.checkbox(
        "sinkhorn", value=True, help="use sinkhorn algorithm"
    )
    # The LaTeX backslashes are escaped explicitly: "\k" and "\l" are not
    # valid Python escape sequences and trigger a warning when left bare.
    distortion = st.sidebar.slider(
        "distortion: $\\kappa$",
        0.0,
        1.0,
        value=0.20,
        help="suppression of off-diagonal alignments",
    )
    tau = st.sidebar.slider(
        "m / $\\tau$",
        0.0,
        1.0,
        value=0.98,
        help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties",
    )
    threshhold = st.sidebar.slider(
        "threshhold: $\\lambda$",
        0.0,
        1.0,
        value=0.22,
        help="sparsity of alignment matrix",
    )
    show_assignments = st.sidebar.checkbox("show assignments", value=True)
    if show_assignments:
        # Only needed for the UMAP plot, so only shown when it is enabled.
        n_neighbors = st.sidebar.slider(
            "n_neighbors", 2, 10, value=8, help="number of neighbors for umap"
        )

    # Content
    st.markdown(
        "## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment"
    )

    col1, col2 = st.columns(2)

    with col1:
        sent1 = st.text_area(
            "sentence 1",
            "By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.",
            help="Initial text",
        )
    with col2:
        sent2 = st.text_area(
            "sentence 2",
            "Today there are only around 20,000 wild lions left in the world.",
            help="Text to compare",
        )

    tokenizer, model = init_model(model)
    aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)

    with st.container():
        # Skip the pipeline until both boxes hold non-whitespace text;
        # whitespace-only input would tokenize to an empty word list.
        if sent1.strip() and sent2.strip():
            sent1 = word_tokenize(sent1.lower())
            sent2 = word_tokenize(sent2.lower())
            print(sent1)
            print(sent2)

            hidden_output, input_id, offset_map = encode_sentence(
                sent1, sent2, tokenizer, model, layer=layer
            )
            if is_centering:
                hidden_output = centering(hidden_output)
            s1_vec, s2_vec = convert_to_word_embeddings(
                offset_map, input_id, hidden_output, tokenizer, pair=True
            )
            (
                align_matrix,
                cost_matrix,
                loss,
                similarity_matrix,
            ) = aligner.compute_alignment_matrixes(s1_vec, s2_vec)
            print(align_matrix.shape, cost_matrix.shape)

            # The heatmap helpers are given the transposed matrices.
            st.write(f"**word alignment matrix** (loss: :blue[{loss}])")
            fig = plot_align_matrix_heatmap_plotly(
                align_matrix.T, sent1, sent2, threshhold, cost_matrix.T
            )
            st.plotly_chart(fig, use_container_width=True)

            st.write("**word similarity matrix**")
            fig2 = plot_similarity_matrix_heatmap_plotly(
                similarity_matrix.T, sent1, sent2, cost_matrix.T
            )
            st.plotly_chart(fig2, use_container_width=True)

            if show_assignments:
                st.write("**Alignments after UMAP**")

                word_embeddings = torch.vstack([s1_vec, s2_vec])
                # Tensor.numpy() only works on CPU tensors, so move the
                # embeddings off the GPU first (a no-op when already on CPU).
                umap_embeddings = umap.UMAP(
                    n_neighbors=n_neighbors,
                    n_components=2,
                    random_state=42,
                    metric="cosine",
                ).fit_transform(word_embeddings.detach().cpu().numpy())
                print(umap_embeddings.shape)
                fig3 = show_assignments_plotly(
                    align_matrix, umap_embeddings, sent1, sent2, thr=threshhold
                )
                st.plotly_chart(fig3, use_container_width=True)

    st.divider()
    st.subheader('Refs')
    st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")


if __name__ == '__main__':
    main()