4kasha committed on
Commit 94f5fd3
1 Parent(s): e7088f8

update demo

Files changed (6)
  1. aligner.py +34 -47
  2. app.py +45 -91
  3. otfuncs.py +68 -0
  4. plotools.py +129 -0
  5. requirements.txt +4 -3
  6. utils.py +64 -100
aligner.py CHANGED
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import ot
-from utils import (
+from otfuncs import (
     compute_distance_matrix_cosine,
     compute_distance_matrix_l2,
     compute_weights_norm,
@@ -30,55 +30,36 @@ class Aligner:
         else:
             self.weight_func = compute_weights_norm
 
-    def compute_alignment_matrixes(self, s1_vecs, s2_vecs):
-        self.align_matrixes = []
-        for vecX, vecY in zip(s1_vecs, s2_vecs):
-            P = self.compute_optimal_transport(vecX, vecY)
-            if torch.is_tensor(P):
-                P = P.to('cpu').numpy()
+    def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
+        P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
+        print(log.keys())
+        if torch.is_tensor(P):
+            P = P.to('cpu').numpy()
+        loss = log.get('cost', 'NotImplemented')
 
-            self.align_matrixes.append(P)
-
-    def get_alignments(self, thresh, assign_cost=False):
-        assert len(self.align_matrixes) > 0
-
-        self.thresh = thresh
-        all_alignments = []
-        for P in self.align_matrixes:
-            alignments = self.matrix_to_alignments(P, assign_cost)
-            all_alignments.append(alignments)
-
-        return all_alignments
-
-    def matrix_to_alignments(self, P, assign_cost):
-        alignments = set()
-        align_pairs = np.transpose(np.nonzero(P > self.thresh))
-        if assign_cost:
-            for i_j in align_pairs:
-                alignments.add('{0}-{1}-{2:.4f}'.format(i_j[0], i_j[1], P[i_j[0], i_j[1]]))
-        else:
-            for i_j in align_pairs:
-                alignments.add('{0}-{1}'.format(i_j[0], i_j[1]))
-
-        return alignments
+        return P, Cost, loss, similarity_matrix
+
 
     def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
         s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
         s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
 
-        C = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
+        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
         s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
 
         if self.ot_type == 'ot':
             s1_weights = s1_weights / s1_weights.sum()
             s2_weights = s2_weights / s2_weights.sum()
-            s1_weights, s2_weights, C = self.comvert_to_numpy(s1_weights, s2_weights, C)
+            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
 
             if self.sinkhorn:
-                P = ot.bregman.sinkhorn_log(s1_weights, s2_weights, C, reg=self.epsilon, stopThr=self.stopThr,
-                                            numItermax=self.numItermax)
+                P, log = ot.bregman.sinkhorn_log(
+                    s1_weights, s2_weights, C,
+                    reg=self.epsilon, stopThr=self.stopThr,
+                    numItermax=self.numItermax, log=True
+                )
             else:
-                P = ot.emd(s1_weights, s2_weights, C)
+                P, log = ot.emd(s1_weights, s2_weights, C, log=True)
             # Min-max normalization
             P = min_max_scaling(P)
 
@@ -89,16 +70,18 @@ class Aligner:
         else:
             m = self.tau
 
-            s1_weights, s2_weights, C = self.comvert_to_numpy(s1_weights, s2_weights, C)
+            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
             m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m
 
             if self.sinkhorn:
-                P = ot.partial.entropic_partial_wasserstein(s1_weights, s2_weights, C,
-                                                            reg=self.epsilon,
-                                                            m=m, stopThr=self.stopThr, numItermax=self.numItermax)
+                P, log = ot.partial.entropic_partial_wasserstein(
+                    s1_weights, s2_weights, C,
+                    reg=self.epsilon,
+                    m=m, stopThr=self.stopThr, numItermax=self.numItermax, log=True
+                )
            else:
                 # To cope with round error
-                P = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m)
+                P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)
             # Min-max normalization
             P = min_max_scaling(P)
 
@@ -109,20 +92,24 @@ class Aligner:
             tau = self.tau
 
             if self.ot_type == 'uot':
-                P = ot.unbalanced.sinkhorn_stabilized_unbalanced(s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau,
-                                                                 stopThr=self.stopThr, numItermax=self.numItermax)
+                P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
+                    s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau,
+                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
+                )
             elif self.ot_type == 'uot-mm':
-                P = ot.unbalanced.mm_unbalanced(s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
-                                                stopThr=self.stopThr, numItermax=self.numItermax)
+                P, log = ot.unbalanced.mm_unbalanced(
+                    s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
+                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
+                )
             # Min-max normalization
             P = min_max_scaling(P)
 
         elif self.ot_type == 'none':
             P = 1 - C
 
-        return P
+        return P, C, log, similarity_matrix
 
-    def comvert_to_numpy(self, s1_weights, s2_weights, C):
+    def convert_to_numpy(self, s1_weights, s2_weights, C):
         if torch.is_tensor(s1_weights):
             s1_weights = s1_weights.to('cpu').numpy()
             s2_weights = s2_weights.to('cpu').numpy()
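
Note on the log=True pattern introduced above: POT solvers called with log=True return a (P, log) pair, and ot.emd reports the transport cost under log['cost']; other solvers may not expose that key, which is why compute_alignment_matrixes falls back to log.get('cost', 'NotImplemented'). A minimal standalone sketch, not part of the commit:

import numpy as np
import ot

a = np.ones(3) / 3        # uniform source weights
b = np.ones(4) / 4        # uniform target weights
C = np.random.rand(3, 4)  # toy cost matrix

# With log=True the solver also returns a log dict alongside the transport plan.
P, log = ot.emd(a, b, C, log=True)
print(P.shape, log['cost'])  # (3, 4) transport plan and its total transport cost
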
app.py CHANGED
@@ -2,14 +2,22 @@ import streamlit as st
 import random
 import numpy as np
 import torch
+from nltk.tokenize import word_tokenize
 from transformers import AutoTokenizer, AutoModel
 from aligner import Aligner
-from utils import plot_align_matrix_heatmap
+from utils import (
+    encode_sentence,
+    centering,
+    convert_to_word_embeddings
+)
+from plotools import plot_align_matrix_heatmap_plotly, plot_similarity_matrix_heatmap_plotly
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch.manual_seed(42)
 np.random.seed(42)
 random.seed(42)
+import nltk
+nltk.download('punkt')
 
 
 @st.cache_resource
@@ -34,72 +42,6 @@ def init_aligner(ot_type: str, sinkhorn: bool, distortion: float, threshhold: float
     )
 
 
-def encode_sentence(sent, pair, tokenizer, model, layer: int):
-    if pair == None:
-        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True, return_offsets_mapping=True,
-                           return_tensors="pt")
-        with torch.no_grad():
-            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
-                            inputs['token_type_ids'].to(device))
-    else:
-        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
-                           is_split_into_words=True,
-                           return_offsets_mapping=True, return_tensors="pt")
-        with torch.no_grad():
-            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
-                            inputs['token_type_ids'].to(device))
-
-    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
-
-
-def centering(hidden_outputs):
-    """
-    hidden_outputs : [tokens, hidden_size]
-    """
-    # Sum the embeddings of all tokens and compute their mean vector
-    mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
-    hidden_outputs = hidden_outputs - mean_vec
-    print(hidden_outputs.shape)
-    return hidden_outputs
-
-
-def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
-    word_idx = -1
-    subword_to_word_conv = np.full((hidden_tensors.shape[0]), -1)
-    # Bug in hugging face tokenizer? Sometimes Metaspace is inserted
-    metaspace = getattr(tokenizer.decoder, "replacement", None)
-    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
-    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids)) if
-                           x == metaspace]
-
-    for subw_idx, offset in enumerate(offset_mapping):
-        if subw_idx in tokenizer_bug_idxes:
-            continue
-        elif offset[0] == offset[1]:  # Special token
-            continue
-        elif offset[0] == 0:
-            word_idx += 1
-            subword_to_word_conv[subw_idx] = word_idx
-        else:
-            subword_to_word_conv[subw_idx] = word_idx
-
-    word_embeddings = torch.vstack(
-        ([torch.mean(hidden_tensors[subword_to_word_conv == word_idx], dim=0) for word_idx in range(word_idx + 1)]))
-    print(word_embeddings.shape)
-
-    if pair:
-        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
-        s2_start_idx = subword_to_word_conv[
-            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]
-
-        s1_word_embeddigs = word_embeddings[0:s2_start_idx, :]
-        s2_word_embeddigs = word_embeddings[s2_start_idx:, :]
-
-        return s1_word_embeddigs, s2_word_embeddigs
-    else:
-        return word_embeddings
-
-
 def main():
     st.set_page_config(layout="wide")
 
@@ -107,21 +49,30 @@ def main():
     st.sidebar.markdown("## Settings & Parameters")
     model = st.sidebar.selectbox('model', ['microsoft/deberta-v3-base', 'bert-base-uncased'])
     layer = st.sidebar.slider(
-        'layer number for embeddings', 0, 11, value=9
+        'layer number for embeddings', 0, 11, value=9,
     )
     is_centering = st.sidebar.checkbox('centering embeddings', value=True)
-    ot_type = st.sidebar.selectbox('ot_type', ['OT', 'POT', 'UOT'])
+    ot_type = st.sidebar.selectbox(
+        'ot_type', ['OT', 'POT', 'UOT'],
+        help="optimal transport algorithm to be used"
+    )
     ot_type = ot_type.lower()
-    sinkhorn = st.sidebar.checkbox('sinkhorn', value=True)
+    sinkhorn = st.sidebar.checkbox(
+        'sinkhorn', value=True,
+        help="use sinkhorn algorithm"
+    )
     distortion = st.sidebar.slider(
-        'distortion: $\kappa$', 0.0, 1.0, value=0.20
+        'distortion: $\kappa$', 0.0, 1.0, value=0.20,
+        help="suppression of off-diagonal alignments"
     )
     tau = st.sidebar.slider(
-        'tau: $\\tau$', 0.0, 1.0, value=0.98
-    )  # with 0.02 interval
+        'm / $\\tau$', 0.0, 1.0, value=0.98,
+        help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties"
+    )
     threshhold = st.sidebar.slider(
-        'threshhold: $\lambda$', 0.0, 1.0
-    )  # with 0.01 interval
+        'threshhold: $\lambda$', 0.0, 1.0, value=0.22,
+        help="sparsity of alignment matrix"
+    )
 
     # Content
     st.markdown('## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment')
@@ -130,39 +81,42 @@
 
     with col1:
         sent1 = st.text_area(
-            'sentence 1',
-            'By one estimate , fewer than 20,000 lions exist in the wild , a drop of about 40 percent in the past two decades .'
+            'sentence 1',
+            'By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.',
+            help="Initial text"
         )
     with col2:
         sent2 = st.text_area(
-            'sentence 2',
-            'Today there are only around 20,000 wild lions left in the world .'
+            'sentence 2',
+            'Today there are only around 20,000 wild lions left in the world.',
+            help="Text to compare"
        )
 
     tokenizer, model = init_model(model)
     aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)
 
     with st.container():
-        st.write("word alignment matrix")
-
         if sent1 != '' and sent2 != '':
-            sent1 = sent1.lower().split()
-            sent2 = sent2.lower().split()
+            sent1 = word_tokenize(sent1.lower())
+            sent2 = word_tokenize(sent2.lower())
+            print(sent1)
+            print(sent2)
             hidden_output, input_id, offset_map = encode_sentence(sent1, sent2, tokenizer, model, layer=layer)
             if is_centering:
                 hidden_output = centering(hidden_output)
             s1_vec, s2_vec = convert_to_word_embeddings(offset_map, input_id, hidden_output, tokenizer, pair=True)
-            aligner.compute_alignment_matrixes([s1_vec], [s2_vec])
-            align_matrix = aligner.align_matrixes[0]
-            print(align_matrix.shape)
+            align_matrix, cost_matrix, loss, similarity_matrix = aligner.compute_alignment_matrixes(s1_vec, s2_vec)
+            print(align_matrix.shape, cost_matrix.shape)
 
-            #fig = align_matrix_heatmap(align_matrix.T, sent1, sent2, threshhold)
-            #st.plotly_chart(fig, use_container_width=True)
-            fig = plot_align_matrix_heatmap(align_matrix.T, sent1, sent2, threshhold)
-            st.pyplot(fig, dpi=300)
+            st.write(f"**word alignment matrix** (loss: :blue[{loss}])")
+            fig = plot_align_matrix_heatmap_plotly(align_matrix.T, sent1, sent2, threshhold, cost_matrix.T)
+            st.plotly_chart(fig, use_container_width=True)
+
+            st.write(f"**word similarity matrix**")
+            fig2 = plot_similarity_matrix_heatmap_plotly(similarity_matrix.T, sent1, sent2, cost_matrix.T)
+            st.plotly_chart(fig2, use_container_width=True)
 
     st.divider()
-    st.markdown("Note that the centering in this demo is applied only to the input sentences, so the variance may be large.")
     st.subheader('Refs')
     st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
 
otfuncs.py ADDED
@@ -0,0 +1,68 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from ot.backend import get_backend
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def compute_distance_matrix_cosine(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
+    sim_matrix = (torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t()) + 1.0) / 2  # Range 0-1
+    C = apply_distortion(sim_matrix, distortion_ratio)
+    C = min_max_scaling(C)  # Range 0-1
+    C = 1.0 - C  # Convert to distance
+
+    return C, sim_matrix
+
+
+def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
+    C = torch.cdist(s1_word_embeddigs, s2_word_embeddigs, p=2)
+    C = min_max_scaling(C)  # Range 0-1
+    sim_matrix = 1.0 - C  # Convert to similarity
+    C = apply_distortion(sim_matrix, distortion_ratio)
+    C = min_max_scaling(C)  # Range 0-1
+    C = 1.0 - C  # Convert to distance
+
+    return C, sim_matrix  # also return the similarity, matching the cosine variant (the caller in aligner.py unpacks two values)
+
+
+def apply_distortion(sim_matrix, ratio):
+    shape = sim_matrix.shape
+    if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
+        return sim_matrix
+
+    pos_x = torch.tensor([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
+                         device=device)
+    pos_y = torch.tensor([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
+                         device=device)
+    distortion_mask = 1.0 - ((pos_x - pos_y.T) ** 2) * ratio
+
+    sim_matrix = torch.mul(sim_matrix, distortion_mask)
+
+    return sim_matrix
+
+
+def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
+    s1_weights = torch.norm(s1_word_embeddigs, dim=1)
+    s2_weights = torch.norm(s2_word_embeddigs, dim=1)
+    return s1_weights, s2_weights
+
+
+def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
+    s1_weights = torch.ones(s1_word_embeddigs.shape[0], dtype=torch.float64, device=device)
+    s2_weights = torch.ones(s2_word_embeddigs.shape[0], dtype=torch.float64, device=device)
+
+    # # Uniform weights to make L2 norm=1
+    # s1_weights /= torch.linalg.norm(s1_weights)
+    # s2_weights /= torch.linalg.norm(s2_weights)
+
+    return s1_weights, s2_weights
+
+
+def min_max_scaling(C):
+    eps = 1e-10
+    # Min-max scaling for stabilization
+    nx = get_backend(C)
+    C_min = nx.min(C)
+    C_max = nx.max(C)
+    C = (C - C_min + eps) / (C_max - C_min + eps)
+    return C
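
A quick smoke test of the new otfuncs helpers (a sketch; it assumes otfuncs.py from this commit is importable and uses arbitrary toy shapes):

import torch
from otfuncs import compute_distance_matrix_cosine, compute_weights_norm, device

# Four and six toy "word" embeddings of dimension 8, created on the module's device.
x = torch.randn(4, 8, dtype=torch.float64, device=device)
y = torch.randn(6, 8, dtype=torch.float64, device=device)

C, sim = compute_distance_matrix_cosine(x, y, distortion_ratio=0.2)  # distance matrix plus raw cosine similarity
w_x, w_y = compute_weights_norm(x, y)                                # norm-based word weights
print(C.shape, sim.shape, w_x.shape)  # torch.Size([4, 6]) torch.Size([4, 6]) torch.Size([4])
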
plotools.py ADDED
@@ -0,0 +1,129 @@
+import numpy as np
+import plotly.graph_objects as go
+
+
+def _debug_non_unique_axis_values(sent1: list[str], sent2: list[str]):
+    """
+    solution:
+    using zero-width-space
+    cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
+    """
+    sent1 = [word + i*'\u200b' for i, word in enumerate(sent1)]
+    sent2 = [word + i*'\u200b' for i, word in enumerate(sent2)]
+
+    return sent1, sent2
+
+
+def discrete_colorscale(bvals, colors):
+    """
+    bvals - list of values bounding intervals/ranges of interest
+    colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0 <= k < len(bvals)-1
+    returns the plotly discrete colorscale
+    ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
+    """
+    if len(bvals) != len(colors)+1:
+        raise ValueError('len(boundary values) should be equal to len(colors)+1')
+    bvals = sorted(bvals)
+    nvals = [(v-bvals[0])/(bvals[-1]-bvals[0]) for v in bvals]  # normalized values
+
+    dcolorscale = []  # discrete colorscale
+    for k in range(len(colors)):
+        dcolorscale.extend([[nvals[k], colors[k]], [nvals[k+1], colors[k]]])
+    return dcolorscale
+
+
+def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
+    align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
+    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
+    _colors = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
+    _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
+
+    colorscale = discrete_colorscale(_ticks, _colors)
+
+    fig = go.Figure()
+
+    fig.add_trace(go.Heatmap(
+        z=align_matrix,
+        customdata=Cost,
+        x=sent1,
+        y=sent2,
+        xgap=2,
+        ygap=2,
+        colorscale=colorscale,
+        colorbar=dict(
+            tick0=0,
+            dtick=0.125,
+            outlinewidth=0
+        ),
+        hovertemplate=
+        'x: %{x}<br>' +
+        'y: %{y}<br>' +
+        'P: %{z:.3f}<br>' +
+        'cost: %{customdata:.3f} ',
+        name=''
+    ))
+    fig.update_layout(
+        #xaxis=dict(scaleanchor='y'),
+        yaxis=dict(autorange='reversed'),
+        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
+        plot_bgcolor='rgba(0,0,0,0)',
+        font=dict(
+            size=16,
+        ),
+        hoverlabel=dict(
+            bgcolor="#555",
+            font_color="white",
+            font_size=14,
+            font_family="Open Sans"
+        )
+    )
+    fig.update_xaxes(
+        tickangle=-45,
+    )
+    return fig
+
+
+def plot_similarity_matrix_heatmap_plotly(similarity_matrix, sent1, sent2, Cost):
+    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
+
+    fig = go.Figure()
+
+    fig.add_trace(go.Heatmap(
+        z=similarity_matrix,
+        customdata=Cost,
+        x=sent1,
+        y=sent2,
+        xgap=2,
+        ygap=2,
+        colorscale="Reds",
+        colorbar=dict(
+            tick0=0,
+            dtick=0.125,
+            outlinewidth=0
+        ),
+        hovertemplate=
+        'x: %{x}<br>' +
+        'y: %{y}<br>' +
+        'cosine: %{z:.3f}<br>' +
+        'cost: %{customdata:.3f} ',
+        name=''
+    ))
+    fig.update_layout(
+        #xaxis=dict(scaleanchor='y'),
+        yaxis=dict(autorange='reversed'),
+        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
+        plot_bgcolor='rgba(0,0,0,0)',
+        font=dict(
+            size=16,
+        ),
+        hoverlabel=dict(
+            bgcolor="#555",
+            font_color="white",
+            font_size=14,
+            font_family="Open Sans"
+        )
+    )
+    fig.update_xaxes(
+        tickangle=-45,
+    )
+    return fig
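
The Plotly helpers can also be exercised outside Streamlit; a toy invocation (a sketch assuming plotools.py from this commit, with random matrices standing in for real alignment and cost values):

import numpy as np
from plotools import plot_align_matrix_heatmap_plotly

sent1 = ['a', 'cat', 'sat', 'down']            # x-axis labels
sent2 = ['the', 'cat', 'sat']                  # y-axis labels
P = np.random.rand(len(sent2), len(sent1))     # toy alignment matrix (app.py passes align_matrix.T)
cost = np.random.rand(len(sent2), len(sent1))  # toy cost matrix shown on hover

fig = plot_align_matrix_heatmap_plotly(P, sent1, sent2, threshhold=0.2, Cost=cost)
fig.show()  # in the app this is rendered via st.plotly_chart(fig, use_container_width=True)
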
requirements.txt CHANGED
@@ -1,8 +1,9 @@
 POT==0.9.0
 sentencepiece==0.1.99
-streamlit==1.24.0
+streamlit==1.27.2
 tokenizers==0.13.3
 transformers==4.30.2
 matplotlib==3.7.1
-seaborn==0.12.2
-torch==2.0.1
+plotly==5.15.0
+torch==2.0.1
+nltk==3.8.1
utils.py CHANGED
@@ -1,105 +1,69 @@
 import numpy as np
 import torch
-import torch.nn.functional as F
-from ot.backend import get_backend
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def compute_distance_matrix_cosine(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
-    C = (torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t()) + 1.0) / 2  # Range 0-1
-    C = apply_distortion(C, distortion_ratio)
-    C = min_max_scaling(C)  # Range 0-1
-    C = 1.0 - C  # Convert to distance
-
-    return C
-
-
-def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
-    C = torch.cdist(s1_word_embeddigs, s2_word_embeddigs, p=2)
-    C = min_max_scaling(C)  # Range 0-1
-    C = 1.0 - C  # Convert to similarity
-    C = apply_distortion(C, distortion_ratio)
-    C = min_max_scaling(C)  # Range 0-1
-    C = 1.0 - C  # Convert to distance
-
-    return C
-
-
-def apply_distortion(sim_matrix, ratio):
-    shape = sim_matrix.shape
-    if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
-        return sim_matrix
-
-    pos_x = torch.tensor([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
-                         device=device)
-    pos_y = torch.tensor([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
-                         device=device)
-    distortion_mask = 1.0 - ((pos_x - pos_y.T) ** 2) * ratio
-
-    sim_matrix = torch.mul(sim_matrix, distortion_mask)
-
-    return sim_matrix
-
-
-def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
-    s1_weights = torch.norm(s1_word_embeddigs, dim=1)
-    s2_weights = torch.norm(s2_word_embeddigs, dim=1)
-    return s1_weights, s2_weights
-
-
-def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
-    s1_weights = torch.ones(s1_word_embeddigs.shape[0], dtype=torch.float64, device=device)
-    s2_weights = torch.ones(s2_word_embeddigs.shape[0], dtype=torch.float64, device=device)
-
-    # # Uniform weights to make L2 norm=1
-    # s1_weights /= torch.linalg.norm(s1_weights)
-    # s2_weights /= torch.linalg.norm(s2_weights)
-
-    return s1_weights, s2_weights
-
-
-def min_max_scaling(C):
-    eps = 1e-10
-    # Min-max scaling for stabilization
-    nx = get_backend(C)
-    C_min = nx.min(C)
-    C_max = nx.max(C)
-    C = (C - C_min + eps) / (C_max - C_min + eps)
-    return C
-
-
-import seaborn as sns
-import matplotlib.pyplot as plt
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-def plot_align_matrix_heatmap(align_matrix, sent1, sent2, thresh, **kwargs):
-
-    align_matrix = np.where(align_matrix <= thresh, 0, align_matrix)
-
-    fig, ax = plt.subplots(figsize=(10, 6))
-    sns.set(font='sans-serif', style="ticks")
-
-    _color = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
-    _ticks = [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
-
-    divider = make_axes_locatable(ax)
-    cbar_ax = divider.append_axes("right", size="2.5%", pad=0.1)
-    fig.add_axes(cbar_ax)
-    ax = sns.heatmap(
-        align_matrix,
-        xticklabels=sent1,
-        yticklabels=sent2,
-        cmap=_color,
-        linewidths=1,
-        square=True,
-        ax=ax,
-        cbar_ax=cbar_ax,
-        **kwargs
-    )
-    ax.collections[0].colorbar.ax.yaxis.set_ticks(_ticks, minor=False)
-    ax.collections[0].colorbar.set_ticklabels(_ticks)
-    cax = ax.collections[0].colorbar.ax
-    cax.tick_params(which='major', length=3, labelsize=5)
-    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
-    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
-    return fig
+def encode_sentence(sent, pair, tokenizer, model, layer: int):
+    if pair == None:
+        inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True, return_offsets_mapping=True,
+                           return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
+                            inputs['token_type_ids'].to(device))
+    else:
+        inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True,
+                           is_split_into_words=True,
+                           return_offsets_mapping=True, return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
+                            inputs['token_type_ids'].to(device))
+
+    return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
+
+
+def centering(hidden_outputs):
+    """
+    hidden_outputs : [tokens, hidden_size]
+    """
+    # Sum the embeddings of all tokens and compute their mean vector
+    mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
+    hidden_outputs = hidden_outputs - mean_vec
+    print(hidden_outputs.shape)
+    return hidden_outputs
+
+
+def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
+    word_idx = -1
+    subword_to_word_conv = np.full((hidden_tensors.shape[0]), -1)
+    # Bug in hugging face tokenizer? Sometimes Metaspace is inserted
+    metaspace = getattr(tokenizer.decoder, "replacement", None)
+    metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
+    tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids)) if
+                           x == metaspace]
+
+    for subw_idx, offset in enumerate(offset_mapping):
+        if subw_idx in tokenizer_bug_idxes:
+            continue
+        elif offset[0] == offset[1]:  # Special token
+            continue
+        elif offset[0] == 0:
+            word_idx += 1
+            subword_to_word_conv[subw_idx] = word_idx
+        else:
+            subword_to_word_conv[subw_idx] = word_idx
+
+    word_embeddings = torch.vstack(
+        ([torch.mean(hidden_tensors[subword_to_word_conv == word_idx], dim=0) for word_idx in range(word_idx + 1)]))
+    print(word_embeddings.shape)
+
+    if pair:
+        sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
+        s2_start_idx = subword_to_word_conv[
+            sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]
+
+        s1_word_embeddigs = word_embeddings[0:s2_start_idx, :]
+        s2_word_embeddigs = word_embeddings[s2_start_idx:, :]
+
+        return s1_word_embeddigs, s2_word_embeddigs
+    else:
+        return word_embeddings
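
For reference, the helpers that moved into utils.py can be chained outside the app roughly the way main() does (a sketch; the model name and layer follow the choices offered in app.py, and output_hidden_states must be enabled because encode_sentence reads outputs.hidden_states):

import torch
from transformers import AutoTokenizer, AutoModel
from utils import encode_sentence, centering, convert_to_word_embeddings, device

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased', output_hidden_states=True).to(device).eval()

sent1 = 'the cat sat down'.lower().split()
sent2 = 'a cat was sitting'.lower().split()

# Encode the pair, mean-center the subword states, then pool subwords into word vectors.
hidden, input_ids, offsets = encode_sentence(sent1, sent2, tokenizer, model, layer=9)
hidden = centering(hidden)
s1_vec, s2_vec = convert_to_word_embeddings(offsets, input_ids, hidden, tokenizer, pair=True)
print(s1_vec.shape, s2_vec.shape)  # one embedding row per word in each sentence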