UOT / app.py
4kasha
add link
35ebeb7
import random
import numpy as np
import streamlit as st
import nltk
import torch
import umap
from nltk.tokenize import word_tokenize
from transformers import AutoModel, AutoTokenizer
from aligner import Aligner
from plotools import (
plot_align_matrix_heatmap_plotly,
plot_similarity_matrix_heatmap_plotly,
show_assignments_plotly,
)
from utils import centering, convert_to_word_embeddings, encode_sentence
# Run the encoder on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the torch, NumPy and stdlib random number generators with one fixed value.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Tokenizer data used by nltk.tokenize.word_tokenize.  Older NLTK releases
# read the "punkt" package; newer releases look up "punkt_tab" instead and
# raise LookupError when only "punkt" is installed, so fetch both.
nltk.download("punkt")
nltk.download("punkt_tab")
@st.cache_resource
def init_model(model: str):
    """Load the tokenizer and encoder for the given checkpoint.

    The result is wrapped in ``st.cache_resource``, so Streamlit reuses the
    loaded objects for repeated calls with the same ``model`` argument.

    Args:
        model: Identifier handed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.

    Returns:
        A ``(tokenizer, encoder)`` tuple. The encoder is created with
        ``output_hidden_states=True``, moved to the module-level ``device``
        and switched to evaluation mode.
    """
    checkpoint = model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    encoder = AutoModel.from_pretrained(checkpoint, output_hidden_states=True)
    encoder = encoder.to(device)
    encoder = encoder.eval()

    return tokenizer, encoder
@st.cache_resource(max_entries=100)
def init_aligner(
    ot_type: str, sinkhorn: bool, distortion: float, threshhold: float, tau: float
):
    """Build an ``Aligner`` for the chosen settings.

    Wrapped in ``st.cache_resource(max_entries=100)``, so Streamlit keeps at
    most 100 cached aligners keyed on the argument values.

    Args:
        ot_type: Passed through as the aligner's ``ot_type``.
        sinkhorn: Passed through as the aligner's ``sinkhorn`` flag.
        distortion: Passed through as the aligner's ``distortion``.
        threshhold: Passed to the aligner as ``thresh``.
        tau: Passed through as the aligner's ``tau``.

    Returns:
        The constructed ``Aligner`` instance.
    """
    # Options driven by the caller.
    user_options = {
        "ot_type": ot_type,
        "sinkhorn": sinkhorn,
        "distortion": distortion,
        "thresh": threshhold,
        "tau": tau,
    }
    # Options this app never varies.
    fixed_options = {
        "dist_type": "cos",
        "weight_type": "uniform",
        "div_type": "--",
    }
    return Aligner(**user_options, **fixed_options)
def main():
    """Render the word-alignment playground page.

    Builds the sidebar controls (model, layer, centering, transport type and
    its parameters, display toggles), reads two sentences from text areas,
    and, when both sentences contain non-whitespace text, tokenizes them,
    embeds them with the selected model, computes the alignment with the
    configured ``Aligner`` and draws the alignment heatmap plus the optional
    similarity heatmap and UMAP assignment plot. A reference section is
    rendered at the bottom of the page.
    """
    st.set_page_config(layout="wide")

    # Sidebar
    st.sidebar.markdown("## Settings & Parameters")
    model_name = st.sidebar.selectbox(
        "model", ["microsoft/deberta-v3-base", "bert-base-uncased"]
    )
    layer = st.sidebar.slider(
        "layer number for embeddings",
        0,
        11,
        value=9,
    )
    is_centering = st.sidebar.checkbox("centering embeddings", value=True)
    ot_type = st.sidebar.radio(
        "type",
        ["OT", "POT", "UOT"],
        index=1,
        horizontal=True,
        help="optimal transport algorithm to be used",
    )
    ot_type = ot_type.lower()
    sinkhorn = st.sidebar.checkbox(
        "sinkhorn", value=True, help="use sinkhorn algorithm"
    )
    # The backslashes in the LaTeX labels are escaped explicitly; a bare
    # "\k" or "\l" is an invalid escape sequence in a normal string literal.
    distortion = st.sidebar.slider(
        "distortion: $\\kappa$",
        0.0,
        1.0,
        value=0.20,
        help="suppression of off-diagonal alignments",
    )
    tau = st.sidebar.slider(
        "m / $\\tau$",
        0.0,
        1.0,
        value=0.98,
        help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties",
    )
    threshhold = st.sidebar.slider(
        "threshhold: $\\lambda$",
        0.0,
        1.0,
        value=0.22,
        help="sparsity of alignment matrix",
    )
    show_similarity = st.sidebar.checkbox("show similarity matrix", value=True)
    show_assignments = st.sidebar.checkbox("show assignments", value=True)
    if show_assignments:
        # Only needed (and only shown) when the UMAP plot is enabled.
        n_neighbors = st.sidebar.slider(
            "n_neighbors (see [details](https://umap-learn.readthedocs.io/en/latest/parameters.html#n-neighbors).)",
            2,
            15,
            value=8,
            help="number of nearest neighbors for umap balancing between the preservation of local and global structures",
        )

    # Content
    st.markdown(
        "## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment"
    )
    col1, col2 = st.columns(2)
    with col1:
        sent1 = st.text_area(
            "sentence 1",
            "By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.",
            help="Initial text",
        )
    with col2:
        sent2 = st.text_area(
            "sentence 2",
            "Today there are only around 20,000 wild lions left in the world.",
            help="Text to compare",
        )

    tokenizer, model = init_model(model_name)
    aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)

    with st.container():
        # Skip the pipeline for empty *or* whitespace-only input: such text
        # tokenizes to zero words, leaving nothing to embed or align.
        if sent1.strip() and sent2.strip():
            sent1 = word_tokenize(sent1.lower())
            sent2 = word_tokenize(sent2.lower())
            print(sent1)
            print(sent2)

            hidden_output, input_id, offset_map = encode_sentence(
                sent1, sent2, tokenizer, model, layer=layer
            )
            if is_centering:
                hidden_output = centering(hidden_output)
            s1_vec, s2_vec = convert_to_word_embeddings(
                offset_map, input_id, hidden_output, tokenizer, pair=True
            )
            align_matrix, cost_matrix, loss, similarity_matrix = (
                aligner.compute_alignment_matrixes(s1_vec, s2_vec)
            )
            print(align_matrix.shape, cost_matrix.shape)

            st.write(f"**Word alignment matrix** (loss: :blue[{loss}])")
            fig = plot_align_matrix_heatmap_plotly(
                align_matrix.T, sent1, sent2, threshhold, cost_matrix.T
            )
            st.plotly_chart(fig, use_container_width=True)

            if show_similarity:
                st.write("**Word similarity matrix**")
                fig2 = plot_similarity_matrix_heatmap_plotly(
                    similarity_matrix.T, sent1, sent2, cost_matrix.T
                )
                st.plotly_chart(fig2, use_container_width=True)

            if show_assignments:
                st.write("**Alignments after UMAP**")
                word_embeddings = torch.vstack([s1_vec, s2_vec])
                # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op
                # when the embeddings are already there and avoids a crash
                # when they live on the GPU.
                umap_embeddings = umap.UMAP(
                    n_neighbors=n_neighbors,
                    n_components=2,
                    random_state=42,
                    metric="cosine",
                ).fit_transform(word_embeddings.detach().cpu().numpy())
                print(umap_embeddings.shape)
                fig3 = show_assignments_plotly(
                    align_matrix, umap_embeddings, sent1, sent2, thr=threshhold
                )
                st.plotly_chart(fig3, use_container_width=True)

    st.divider()
    st.subheader('Refs')
    st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
# Script entry point: build the page only when this file is executed directly.
if "__main__" == __name__:
    main()