4kasha committed
Commit: 94f5fd3
1 parent: e7088f8

update demo

Files changed:
- aligner.py: +34 -47
- app.py: +45 -91
- otfuncs.py: +68 -0
- plotools.py: +129 -0
- requirements.txt: +4 -3
- utils.py: +64 -100
aligner.py CHANGED

@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import ot
-from utils import (
+from otfuncs import (
     compute_distance_matrix_cosine,
     compute_distance_matrix_l2,
     compute_weights_norm,
@@ -30,55 +30,36 @@ class Aligner:
         else:
             self.weight_func = compute_weights_norm

-    def compute_alignment_matrixes(self, …
-        self.…
-
-    def get_alignments(self, thresh, assign_cost=False):
-        assert len(self.align_matrixes) > 0
-
-        self.thresh = thresh
-        all_alignments = []
-        for P in self.align_matrixes:
-            alignments = self.matrix_to_alignments(P, assign_cost)
-            all_alignments.append(alignments)
-
-        return all_alignments
-
-    def matrix_to_alignments(self, P, assign_cost):
-        alignments = set()
-        align_pairs = np.transpose(np.nonzero(P > self.thresh))
-        if assign_cost:
-            for i_j in align_pairs:
-                alignments.add('{0}-{1}-{2:.4f}'.format(i_j[0], i_j[1], P[i_j[0], i_j[1]]))
-        else:
-            for i_j in align_pairs:
-                alignments.add('{0}-{1}'.format(i_j[0], i_j[1]))
-
-        return alignments
+    def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
+        P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
+        print(log.keys())
+        if torch.is_tensor(P):
+            P = P.to('cpu').numpy()
+        loss = log.get('cost', 'NotImplemented')
+
+        return P, Cost, loss, similarity_matrix

     def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
         s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
         s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

-        C = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
+        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
         s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

         if self.ot_type == 'ot':
             s1_weights = s1_weights / s1_weights.sum()
             s2_weights = s2_weights / s2_weights.sum()
-            s1_weights, s2_weights, C = self.…
+            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)

             if self.sinkhorn:
-                P = ot.bregman.sinkhorn_log(…
+                P, log = ot.bregman.sinkhorn_log(
+                    s1_weights, s2_weights, C,
+                    reg=self.epsilon, stopThr=self.stopThr,
+                    numItermax=self.numItermax, log=True
+                )
             else:
-                P = ot.emd(s1_weights, s2_weights, C)
+                P, log = ot.emd(s1_weights, s2_weights, C, log=True)
             # Min-max normalization
             P = min_max_scaling(P)

@@ -89,16 +70,18 @@ class Aligner:
             else:
                 m = self.tau

-            s1_weights, s2_weights, C = self.…
+            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
             m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m

             if self.sinkhorn:
-                P = ot.partial.entropic_partial_wasserstein(…
+                P, log = ot.partial.entropic_partial_wasserstein(
+                    s1_weights, s2_weights, C,
+                    reg=self.epsilon,
+                    m=m, stopThr=self.stopThr, numItermax=self.numItermax, log=True
+                )
             else:
                 # To cope with round error
-                P = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m)
+                P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)
             # Min-max normalization
             P = min_max_scaling(P)

@@ -109,20 +92,24 @@ class Aligner:
                 tau = self.tau

             if self.ot_type == 'uot':
-                P = ot.unbalanced.sinkhorn_stabilized_unbalanced(…
+                P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
+                    s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau,
+                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
+                )
             elif self.ot_type == 'uot-mm':
-                P = ot.unbalanced.mm_unbalanced(…
+                P, log = ot.unbalanced.mm_unbalanced(
+                    s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
+                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
+                )
             # Min-max normalization
             P = min_max_scaling(P)

         elif self.ot_type == 'none':
             P = 1 - C

-        return P
+        return P, C, log, similarity_matrix

-    def …
+    def convert_to_numpy(self, s1_weights, s2_weights, C):
         if torch.is_tensor(s1_weights):
             s1_weights = s1_weights.to('cpu').numpy()
             s2_weights = s2_weights.to('cpu').numpy()
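Every POT solver call in this commit switches to the library's log=True variants so the demo can surface the transport loss next to the heatmap. A minimal sketch of that return convention, with toy marginals and a toy cost matrix invented for illustration; note that only ot.emd reports a 'cost' key, while sinkhorn_log's log dict carries the dual variables instead, which is why the code above falls back to 'NotImplemented':

# Sketch only (not part of the commit): toy marginals/cost, POT==0.9.0 API.
import numpy as np
import ot

a = np.array([0.5, 0.5])         # source marginal (sums to 1)
b = np.array([0.3, 0.3, 0.4])    # target marginal (sums to 1)
C = np.random.rand(2, 3)         # toy cost matrix

# Exact OT: the log dict includes the transport cost under 'cost'.
P, log = ot.emd(a, b, C, log=True)
print(P.shape, log['cost'])

# Entropic OT in log-space (the 'sinkhorn' checkbox path): no 'cost' key,
# hence the demo's log.get('cost', 'NotImplemented') fallback.
P, log = ot.bregman.sinkhorn_log(a, b, C, reg=0.1, log=True)
print(log.keys())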
app.py CHANGED

@@ -2,14 +2,22 @@ import streamlit as st
 import random
 import numpy as np
 import torch
+from nltk.tokenize import word_tokenize
 from transformers import AutoTokenizer, AutoModel
 from aligner import Aligner
-from utils import …
+from utils import (
+    encode_sentence,
+    centering,
+    convert_to_word_embeddings
+)
+from plotools import plot_align_matrix_heatmap_plotly, plot_similarity_matrix_heatmap_plotly

 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch.manual_seed(42)
 np.random.seed(42)
 random.seed(42)
+import nltk
+nltk.download('punkt')


 @st.cache_resource
@@ -34,72 +42,6 @@ def init_aligner(ot_type: str, sinkhorn: bool, distortion: float, threshhold: fl
     )


-def encode_sentence(sent, pair, tokenizer, model, layer: int):
-    if pair == None:
-        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True, return_offsets_mapping=True,
-                           return_tensors="pt")
-        with torch.no_grad():
-            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
-                            inputs['token_type_ids'].to(device))
-    else:
-        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
-                           is_split_into_words=True,
-                           return_offsets_mapping=True, return_tensors="pt")
-        with torch.no_grad():
-            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
-                            inputs['token_type_ids'].to(device))
-
-    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
-
-
-def centering(hidden_outputs):
-    """
-    hidden_outputs : [tokens, hidden_size]
-    """
-    # Sum the embeddings over all tokens and compute their mean vector
-    mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
-    hidden_outputs = hidden_outputs - mean_vec
-    print(hidden_outputs.shape)
-    return hidden_outputs
-
-
-def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
-    word_idx = -1
-    subword_to_word_conv = np.full((hidden_tensors.shape[0]), -1)
-    # Bug in hugging face tokenizer? Sometimes Metaspace is inserted
-    metaspace = getattr(tokenizer.decoder, "replacement", None)
-    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
-    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids)) if
-                           x == metaspace]
-
-    for subw_idx, offset in enumerate(offset_mapping):
-        if subw_idx in tokenizer_bug_idxes:
-            continue
-        elif offset[0] == offset[1]:  # Special token
-            continue
-        elif offset[0] == 0:
-            word_idx += 1
-            subword_to_word_conv[subw_idx] = word_idx
-        else:
-            subword_to_word_conv[subw_idx] = word_idx
-
-    word_embeddings = torch.vstack(
-        ([torch.mean(hidden_tensors[subword_to_word_conv == word_idx], dim=0) for word_idx in range(word_idx + 1)]))
-    print(word_embeddings.shape)
-
-    if pair:
-        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
-        s2_start_idx = subword_to_word_conv[
-            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]
-
-        s1_word_embeddigs = word_embeddings[0:s2_start_idx, :]
-        s2_word_embeddigs = word_embeddings[s2_start_idx:, :]
-
-        return s1_word_embeddigs, s2_word_embeddigs
-    else:
-        return word_embeddings
-
-
 def main():
     st.set_page_config(layout="wide")

@@ -107,21 +49,30 @@ def main():
     st.sidebar.markdown("## Settings & Parameters")
     model = st.sidebar.selectbox('model', ['microsoft/deberta-v3-base', 'bert-base-uncased'])
     layer = st.sidebar.slider(
-        …
+        'layer number for embeddings', 0, 11, value=9,
     )
     is_centering = st.sidebar.checkbox('centering embeddings', value=True)
-    ot_type = st.sidebar.selectbox(…
+    ot_type = st.sidebar.selectbox(
+        'ot_type', ['OT', 'POT', 'UOT'],
+        help="optimal transport algorithm to be used"
+    )
     ot_type = ot_type.lower()
-    sinkhorn = st.sidebar.checkbox(…
+    sinkhorn = st.sidebar.checkbox(
+        'sinkhorn', value=True,
+        help="use sinkhorn algorithm"
+    )
     distortion = st.sidebar.slider(
-        …
+        'distortion: $\kappa$', 0.0, 1.0, value=0.20,
+        help="suppression of off-diagonal alignments"
     )
     tau = st.sidebar.slider(
-        …
+        'm / $\\tau$', 0.0, 1.0, value=0.98,
+        help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties"
+    )
     threshhold = st.sidebar.slider(
-        …
+        'threshhold: $\lambda$', 0.0, 1.0, value=0.22,
+        help="sparsity of alignment matrix"
+    )

     # Content
     st.markdown('## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment')
@@ -130,39 +81,42 @@ def main():

     with col1:
         sent1 = st.text_area(
-            …
+            'sentence 1',
+            'By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.',
+            help="Initial text"
         )
     with col2:
         sent2 = st.text_area(
-            …
+            'sentence 2',
+            'Today there are only around 20,000 wild lions left in the world.',
+            help="Text to compare"
        )

     tokenizer, model = init_model(model)
     aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)

     with st.container():
-        st.write("word alignment matrix")
-
         if sent1 != '' and sent2 != '':
-            sent1 = sent1.lower()
-            sent2 = sent2.lower()
+            sent1 = word_tokenize(sent1.lower())
+            sent2 = word_tokenize(sent2.lower())
+            print(sent1)
+            print(sent2)
             hidden_output, input_id, offset_map = encode_sentence(sent1, sent2, tokenizer, model, layer=layer)
             if is_centering:
                 hidden_output = centering(hidden_output)
             s1_vec, s2_vec = convert_to_word_embeddings(offset_map, input_id, hidden_output, tokenizer, pair=True)
-            aligner.compute_alignment_matrixes(…
-            align_matrix…
-            print(align_matrix.shape)
+            align_matrix, cost_matrix, loss, similarity_matrix = aligner.compute_alignment_matrixes(s1_vec, s2_vec)
+            print(align_matrix.shape, cost_matrix.shape)

-            fig = …
+            st.write(f"**word alignment matrix** (loss: :blue[{loss}])")
+            fig = plot_align_matrix_heatmap_plotly(align_matrix.T, sent1, sent2, threshhold, cost_matrix.T)
+            st.plotly_chart(fig, use_container_width=True)
+
+            st.write(f"**word similarity matrix**")
+            fig2 = plot_similarity_matrix_heatmap_plotly(similarity_matrix.T, sent1, sent2, cost_matrix.T)
+            st.plotly_chart(fig2, use_container_width=True)

     st.divider()
-    st.markdown("Note that the centering in this demo is applied only to the input sentences, so the variance may be large.")
     st.subheader('Refs')
     st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
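The app now pre-tokenizes both inputs with NLTK before handing word lists to the HF tokenizer (is_split_into_words=True), so the heatmap axes are labeled word by word. A quick sketch of that step, reusing the demo's own default sentence (output shown approximately):

# Sketch of the new pre-tokenization step; requires the punkt models.
import nltk
from nltk.tokenize import word_tokenize

nltk.download('punkt')

sent = 'Today there are only around 20,000 wild lions left in the world.'
print(word_tokenize(sent.lower()))
# ['today', 'there', 'are', 'only', 'around', '20,000', 'wild', 'lions',
#  'left', 'in', 'the', 'world', '.']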
otfuncs.py ADDED

@@ -0,0 +1,68 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from ot.backend import get_backend
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def compute_distance_matrix_cosine(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
+    sim_matrix = (torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t()) + 1.0) / 2  # Range 0-1
+    C = apply_distortion(sim_matrix, distortion_ratio)
+    C = min_max_scaling(C)  # Range 0-1
+    C = 1.0 - C  # Convert to distance
+
+    return C, sim_matrix
+
+
+def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
+    C = torch.cdist(s1_word_embeddigs, s2_word_embeddigs, p=2)
+    C = min_max_scaling(C)  # Range 0-1
+    C = 1.0 - C  # Convert to similarity
+    C = apply_distortion(C, distortion_ratio)
+    C = min_max_scaling(C)  # Range 0-1
+    C = 1.0 - C  # Convert to distance
+
+    return C
+
+
+def apply_distortion(sim_matrix, ratio):
+    shape = sim_matrix.shape
+    if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
+        return sim_matrix
+
+    pos_x = torch.tensor([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
+                         device=device)
+    pos_y = torch.tensor([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
+                         device=device)
+    distortion_mask = 1.0 - ((pos_x - pos_y.T) ** 2) * ratio
+
+    sim_matrix = torch.mul(sim_matrix, distortion_mask)
+
+    return sim_matrix
+
+
+def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
+    s1_weights = torch.norm(s1_word_embeddigs, dim=1)
+    s2_weights = torch.norm(s2_word_embeddigs, dim=1)
+    return s1_weights, s2_weights
+
+
+def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
+    s1_weights = torch.ones(s1_word_embeddigs.shape[0], dtype=torch.float64, device=device)
+    s2_weights = torch.ones(s2_word_embeddigs.shape[0], dtype=torch.float64, device=device)
+
+    # # Uniform weights to make L2 norm=1
+    # s1_weights /= torch.linalg.norm(s1_weights)
+    # s2_weights /= torch.linalg.norm(s2_weights)
+
+    return s1_weights, s2_weights
+
+
+def min_max_scaling(C):
+    eps = 1e-10
+    # Min-max scaling for stabilization
+    nx = get_backend(C)
+    C_min = nx.min(C)
+    C_max = nx.max(C)
+    C = (C - C_min + eps) / (C_max - C_min + eps)
+    return C
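min_max_scaling is written against ot.backend.get_backend, so the same code rescales both torch tensors and the NumPy arrays the solvers return. A small sketch of the cosine path (toy embeddings invented; assumes this commit's otfuncs.py is on the import path):

# Sketch: toy float64 embeddings; the cosine path returns (cost, raw similarity).
import numpy as np
import torch
from otfuncs import compute_distance_matrix_cosine, min_max_scaling, device

s1 = torch.randn(4, 8, dtype=torch.float64, device=device)   # 4 "words" x 8 dims
s2 = torch.randn(6, 8, dtype=torch.float64, device=device)

C, sim = compute_distance_matrix_cosine(s1, s2, distortion_ratio=0.2)
print(C.shape, float(C.min()), float(C.max()))   # (4, 6), ~0.0, 1.0 after scaling

# The same scaling works on NumPy arrays thanks to get_backend:
print(min_max_scaling(np.array([[1.0, 3.0], [2.0, 4.0]])))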
plotools.py ADDED

@@ -0,0 +1,129 @@
+import numpy as np
+import plotly.graph_objects as go
+
+
+def _debug_non_unique_axis_values(sent1: list[str], sent2: list[str]):
+    """
+    solution:
+    using zero-width-space
+    cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
+    """
+    sent1 = [word + i*'\u200b' for i, word in enumerate(sent1)]
+    sent2 = [word + i*'\u200b' for i, word in enumerate(sent2)]
+
+    return sent1, sent2
+
+
+def discrete_colorscale(bvals, colors):
+    """
+    bvals - list of values bounding intervals/ranges of interest
+    colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]], 0 <= k < len(bvals)-1
+    returns the plotly discrete colorscale
+    ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
+    """
+    if len(bvals) != len(colors)+1:
+        raise ValueError('len(boundary values) should be equal to len(colors)+1')
+    bvals = sorted(bvals)
+    nvals = [(v-bvals[0])/(bvals[-1]-bvals[0]) for v in bvals]  # normalized values
+
+    dcolorscale = []  # discrete colorscale
+    for k in range(len(colors)):
+        dcolorscale.extend([[nvals[k], colors[k]], [nvals[k+1], colors[k]]])
+    return dcolorscale
+
+
+def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
+    align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
+    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
+    _colors = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
+    _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
+
+    colorscale = discrete_colorscale(_ticks, _colors)
+
+    fig = go.Figure()
+
+    fig.add_trace(go.Heatmap(
+        z=align_matrix,
+        customdata=Cost,
+        x=sent1,
+        y=sent2,
+        xgap=2,
+        ygap=2,
+        colorscale=colorscale,
+        colorbar=dict(
+            tick0=0,
+            dtick=0.125,
+            outlinewidth=0
+        ),
+        hovertemplate=
+        'x: %{x}<br>' +
+        'y: %{y}<br>' +
+        'P: %{z:.3f}<br>' +
+        'cost: %{customdata:.3f} ',
+        name=''
+    ))
+    fig.update_layout(
+        # xaxis=dict(scaleanchor='y'),
+        yaxis=dict(autorange='reversed'),
+        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
+        plot_bgcolor='rgba(0,0,0,0)',
+        font=dict(
+            size=16,
+        ),
+        hoverlabel=dict(
+            bgcolor="#555",
+            font_color="white",
+            font_size=14,
+            font_family="Open Sans"
+        )
+    )
+    fig.update_xaxes(
+        tickangle=-45,
+    )
+    return fig
+
+
+def plot_similarity_matrix_heatmap_plotly(similarity_matrix, sent1, sent2, Cost):
+    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
+
+    fig = go.Figure()
+
+    fig.add_trace(go.Heatmap(
+        z=similarity_matrix,
+        customdata=Cost,
+        x=sent1,
+        y=sent2,
+        xgap=2,
+        ygap=2,
+        colorscale="Reds",
+        colorbar=dict(
+            tick0=0,
+            dtick=0.125,
+            outlinewidth=0
+        ),
+        hovertemplate=
+        'x: %{x}<br>' +
+        'y: %{y}<br>' +
+        'cosine: %{z:.3f}<br>' +
+        'cost: %{customdata:.3f} ',
+        name=''
+    ))
+    fig.update_layout(
+        # xaxis=dict(scaleanchor='y'),
+        yaxis=dict(autorange='reversed'),
+        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
+        plot_bgcolor='rgba(0,0,0,0)',
+        font=dict(
+            size=16,
+        ),
+        hoverlabel=dict(
+            bgcolor="#555",
+            font_color="white",
+            font_size=14,
+            font_family="Open Sans"
+        )
+    )
+    fig.update_xaxes(
+        tickangle=-45,
+    )
+    return fig
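Plotly treats heatmap axes as categorical and merges duplicate labels, so a sentence with a repeated word would silently lose rows or columns; _debug_non_unique_axis_values sidesteps this by suffixing the i-th label with i zero-width spaces, unique to plotly but invisible to the reader. A toy sketch (random matrix invented; assumes this commit's plotools.py is importable):

import numpy as np
from plotools import plot_align_matrix_heatmap_plotly

# 'the' repeats in both sentences; without the zero-width-space trick the
# duplicate labels would collapse into a single axis category.
sent1 = ['the', 'cat', 'saw', 'the', 'dog']
sent2 = ['the', 'dog', 'was', 'seen']
P = np.random.rand(len(sent2), len(sent1))    # z is indexed (y, x), like align_matrix.T in app.py
fig = plot_align_matrix_heatmap_plotly(P, sent1, sent2, threshhold=0.2, Cost=1 - P)
fig.show()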
requirements.txt CHANGED

@@ -1,8 +1,9 @@
 POT==0.9.0
 sentencepiece==0.1.99
-streamlit==1.…
+streamlit==1.27.2
 tokenizers==0.13.3
 transformers==4.30.2
 matplotlib==3.7.1
-…
-torch==2.0.1
+plotly==5.15.0
+torch==2.0.1
+nltk==3.8.1
utils.py CHANGED

@@ -1,105 +1,69 @@
 import numpy as np
 import torch
-import torch.nn.functional as F
-from ot.backend import get_backend

 device = "cuda" if torch.cuda.is_available() else "cpu"

-def …
-
-import matplotlib.pyplot as plt
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-def plot_align_matrix_heatmap(align_matrix, sent1, sent2, thresh, **kwargs):
-
-    align_matrix = np.where(align_matrix <= thresh, 0, align_matrix)
-
-    fig, ax = plt.subplots(figsize=(10, 6))
-    sns.set(font='sans-serif', style="ticks")
-
-    _color = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
-    _ticks = [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
-
-    divider = make_axes_locatable(ax)
-    cbar_ax = divider.append_axes("right", size="2.5%", pad=0.1)
-    fig.add_axes(cbar_ax)
-    ax = sns.heatmap(
-        align_matrix,
-        xticklabels=sent1,
-        yticklabels=sent2,
-        cmap=_color,
-        linewidths=1,
-        square=True,
-        ax=ax,
-        cbar_ax=cbar_ax,
-        **kwargs
-    )
-    ax.collections[0].colorbar.ax.yaxis.set_ticks(_ticks, minor=False)
-    ax.collections[0].colorbar.set_ticklabels(_ticks)
-    cax = ax.collections[0].colorbar.ax
-    cax.tick_params(which='major', length=3, labelsize=5)
-    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
-    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
-    return fig
+def encode_sentence(sent, pair, tokenizer, model, layer: int):
+    if pair == None:
+        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True, return_offsets_mapping=True,
+                           return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
+                            inputs['token_type_ids'].to(device))
+    else:
+        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
+                           is_split_into_words=True,
+                           return_offsets_mapping=True, return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
+                            inputs['token_type_ids'].to(device))
+
+    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
+
+
+def centering(hidden_outputs):
+    """
+    hidden_outputs : [tokens, hidden_size]
+    """
+    # Sum the embeddings over all tokens and compute their mean vector
+    mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
+    hidden_outputs = hidden_outputs - mean_vec
+    print(hidden_outputs.shape)
+    return hidden_outputs
+
+
+def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
+    word_idx = -1
+    subword_to_word_conv = np.full((hidden_tensors.shape[0]), -1)
+    # Bug in hugging face tokenizer? Sometimes Metaspace is inserted
+    metaspace = getattr(tokenizer.decoder, "replacement", None)
+    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
+    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids)) if
+                           x == metaspace]
+
+    for subw_idx, offset in enumerate(offset_mapping):
+        if subw_idx in tokenizer_bug_idxes:
+            continue
+        elif offset[0] == offset[1]:  # Special token
+            continue
+        elif offset[0] == 0:
+            word_idx += 1
+            subword_to_word_conv[subw_idx] = word_idx
+        else:
+            subword_to_word_conv[subw_idx] = word_idx
+
+    word_embeddings = torch.vstack(
+        ([torch.mean(hidden_tensors[subword_to_word_conv == word_idx], dim=0) for word_idx in range(word_idx + 1)]))
+    print(word_embeddings.shape)
+
+    if pair:
+        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
+        s2_start_idx = subword_to_word_conv[
+            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]
+
+        s1_word_embeddigs = word_embeddings[0:s2_start_idx, :]
+        s2_word_embeddigs = word_embeddings[s2_start_idx:, :]
+
+        return s1_word_embeddigs, s2_word_embeddigs
+    else:
+        return word_embeddings
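convert_to_word_embeddings mean-pools subword vectors back into word vectors using the tokenizer's offset mapping: an offset starting at character 0 opens a new word, and offsets with equal start and end mark special tokens. A self-contained toy sketch of that pooling idea (offsets and tensors invented, not the exact function):

# Toy sketch of the subword-to-word mean pooling used above.
import numpy as np
import torch

hidden = torch.arange(12, dtype=torch.float64).reshape(6, 2)  # 6 subwords, dim 2
offsets = [(0, 0), (0, 3), (3, 5), (0, 4), (0, 2), (0, 0)]    # [CLS], 'foo'+'##bar', 'word', 'x', [SEP]

word_idx = -1
sub2word = np.full(len(offsets), -1)
for i, (s, e) in enumerate(offsets):
    if s == e:        # special token ([CLS]/[SEP]): skip
        continue
    if s == 0:        # offset restarts at character 0 -> new word begins
        word_idx += 1
    sub2word[i] = word_idx

# Average the subword vectors belonging to each word.
words = torch.vstack([hidden[sub2word == w].mean(dim=0) for w in range(word_idx + 1)])
print(sub2word, words.shape)  # [-1  0  0  1  2 -1] torch.Size([3, 2])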