4kasha commited on
Commit
37d364a
1 Parent(s): 94f5fd3
Files changed (5) hide show
  1. aligner.py +8 -19
  2. app.py +85 -42
  3. otfuncs.py +28 -14
  4. plotools.py +137 -75
  5. requirements.txt +2 -1
aligner.py CHANGED
@@ -10,10 +10,9 @@ from otfuncs import (
10
  )
11
 
12
  class Aligner:
13
- def __init__(self, ot_type, sinkhorn, chimera, dist_type, weight_type, distortion, thresh, tau, **kwargs):
14
  self.ot_type = ot_type
15
  self.sinkhorn = sinkhorn
16
- self.chimera = chimera
17
  self.dist_type = dist_type
18
  self.weight_type = weight_type
19
  self.distotion = distortion
@@ -31,20 +30,19 @@ class Aligner:
31
  self.weight_func = compute_weights_norm
32
 
33
  def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
34
- P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
35
  print(log.keys())
36
  if torch.is_tensor(P):
37
  P = P.to('cpu').numpy()
38
  loss = log.get('cost', 'NotImplemented')
39
 
40
- return P, Cost, loss, similarity_matrix
41
 
42
-
43
  def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
44
  s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
45
  s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
46
 
47
- C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
48
  s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
49
 
50
  if self.ot_type == 'ot':
@@ -64,14 +62,8 @@ class Aligner:
64
  P = min_max_scaling(P)
65
 
66
  elif self.ot_type == 'pot':
67
- if self.chimera:
68
- m = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
69
- m = min(1.0, m.item())
70
- else:
71
- m = self.tau
72
-
73
  s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
74
- m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m
75
 
76
  if self.sinkhorn:
77
  P, log = ot.partial.entropic_partial_wasserstein(
@@ -86,10 +78,7 @@ class Aligner:
86
  P = min_max_scaling(P)
87
 
88
  elif 'uot' in self.ot_type:
89
- if self.chimera:
90
- tau = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
91
- else:
92
- tau = self.tau
93
 
94
  if self.ot_type == 'uot':
95
  P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
@@ -107,7 +96,7 @@ class Aligner:
107
  elif self.ot_type == 'none':
108
  P = 1 - C
109
 
110
- return P, C, log, similarity_matrix
111
 
112
  def convert_to_numpy(self, s1_weights, s2_weights, C):
113
  if torch.is_tensor(s1_weights):
@@ -116,4 +105,4 @@ class Aligner:
116
  if torch.is_tensor(C):
117
  C = C.to('cpu').numpy()
118
 
119
- return s1_weights, s2_weights, C
 
10
  )
11
 
12
  class Aligner:
13
+ def __init__(self, ot_type, sinkhorn, dist_type, weight_type, distortion, thresh, tau, **kwargs):
14
  self.ot_type = ot_type
15
  self.sinkhorn = sinkhorn
 
16
  self.dist_type = dist_type
17
  self.weight_type = weight_type
18
  self.distotion = distortion
 
30
  self.weight_func = compute_weights_norm
31
 
32
  def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
33
+ P, Cost, log, similarity_matrix, relative_distance = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
34
  print(log.keys())
35
  if torch.is_tensor(P):
36
  P = P.to('cpu').numpy()
37
  loss = log.get('cost', 'NotImplemented')
38
 
39
+ return P, Cost, loss, similarity_matrix, relative_distance
40
 
 
41
  def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
42
  s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
43
  s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
44
 
45
+ C, similarity_matrix, relative_distance = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
46
  s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
47
 
48
  if self.ot_type == 'ot':
 
62
  P = min_max_scaling(P)
63
 
64
  elif self.ot_type == 'pot':
 
 
 
 
 
 
65
  s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
66
+ m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * self.tau
67
 
68
  if self.sinkhorn:
69
  P, log = ot.partial.entropic_partial_wasserstein(
 
78
  P = min_max_scaling(P)
79
 
80
  elif 'uot' in self.ot_type:
81
+ tau = self.tau
 
 
 
82
 
83
  if self.ot_type == 'uot':
84
  P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
 
96
  elif self.ot_type == 'none':
97
  P = 1 - C
98
 
99
+ return P, C, log, similarity_matrix, relative_distance
100
 
101
  def convert_to_numpy(self, s1_weights, s2_weights, C):
102
  if torch.is_tensor(s1_weights):
 
105
  if torch.is_tensor(C):
106
  C = C.to('cpu').numpy()
107
 
108
+ return s1_weights, s2_weights, C
app.py CHANGED
@@ -1,44 +1,53 @@
1
- import streamlit as st
2
  import random
 
3
  import numpy as np
 
4
  import torch
 
5
  from nltk.tokenize import word_tokenize
6
- from transformers import AutoTokenizer, AutoModel
 
7
  from aligner import Aligner
8
- from utils import (
9
- encode_sentence,
10
- centering,
11
- convert_to_word_embeddings
 
 
12
  )
13
- from plotools import plot_align_matrix_heatmap_plotly, plot_similarity_matrix_heatmap_plotly
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  torch.manual_seed(42)
17
  np.random.seed(42)
18
  random.seed(42)
19
  import nltk
20
- nltk.download('punkt')
 
21
 
22
 
23
  @st.cache_resource
24
  def init_model(model: str):
25
  tokenizer = AutoTokenizer.from_pretrained(model)
26
- model = AutoModel.from_pretrained(model, output_hidden_states=True).to(device).eval()
 
 
27
  return tokenizer, model
28
 
29
 
30
  @st.cache_resource(max_entries=100)
31
- def init_aligner(ot_type: str, sinkhorn: bool, distortion: float, threshhold: float, tau: float):
 
 
32
  return Aligner(
33
  ot_type=ot_type,
34
  sinkhorn=sinkhorn,
35
- chimera=False,
36
  dist_type="cos",
37
  weight_type="uniform",
38
  distortion=distortion,
39
- thresh=threshhold,
40
- tau=tau,
41
- div_type="--"
42
  )
43
 
44
 
@@ -47,51 +56,70 @@ def main():
47
 
48
  # Sidebar
49
  st.sidebar.markdown("## Settings & Parameters")
50
- model = st.sidebar.selectbox('model', ['microsoft/deberta-v3-base', 'bert-base-uncased'])
 
 
51
  layer = st.sidebar.slider(
52
- 'layer number for embeddings', 0, 11, value=9,
 
 
 
53
  )
54
- is_centering = st.sidebar.checkbox('centering embeddings', value=True)
55
  ot_type = st.sidebar.selectbox(
56
- 'ot_type', ['OT', 'POT', 'UOT'],
57
- help="optimal transport algorithm to be used"
58
  )
59
  ot_type = ot_type.lower()
60
  sinkhorn = st.sidebar.checkbox(
61
- 'sinkhorn', value=True,
62
- help="use sinkhorn algorithm"
63
  )
64
  distortion = st.sidebar.slider(
65
- 'distortion: $\kappa$', 0.0, 1.0, value=0.20,
66
- help="suppression of off-diagonal alignments"
 
 
 
67
  )
68
  tau = st.sidebar.slider(
69
- 'm / $\\tau$', 0.0, 1.0, value=0.98,
70
- help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties"
71
- )
 
 
 
72
  threshhold = st.sidebar.slider(
73
- 'threshhold: $\lambda$', 0.0, 1.0, value=0.22,
74
- help="sparsity of alignment matrix"
75
- )
 
 
 
 
 
 
 
 
76
 
77
  # Content
78
- st.markdown('## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment')
 
 
79
 
80
  col1, col2 = st.columns(2)
81
 
82
  with col1:
83
- sent1 = st.text_area(
84
- 'sentence 1',
85
- 'By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.',
86
- help="Initial text"
87
- )
88
  with col2:
89
- sent2 = st.text_area(
90
- 'sentence 2',
91
- 'Today there are only around 20,000 wild lions left in the world.',
92
- help="Text to compare"
93
- )
94
-
95
  tokenizer, model = init_model(model)
96
  aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)
97
 
@@ -115,10 +143,25 @@ def main():
115
  st.write(f"**word similarity matrix**")
116
  fig2 = plot_similarity_matrix_heatmap_plotly(similarity_matrix.T, sent1, sent2, cost_matrix.T)
117
  st.plotly_chart(fig2, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  st.divider()
120
  st.subheader('Refs')
121
  st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
122
 
123
  if __name__ == '__main__':
124
- main()
 
 
1
  import random
2
+
3
  import numpy as np
4
+ import streamlit as st
5
  import torch
6
+ import umap
7
  from nltk.tokenize import word_tokenize
8
+ from transformers import AutoModel, AutoTokenizer
9
+
10
  from aligner import Aligner
11
+
12
+ # from utils import align_matrix_heatmap, plot_align_matrix_heatmap
13
+ from plotools import (
14
+ plot_align_matrix_heatmap_plotly,
15
+ plot_similarity_matrix_heatmap_plotly,
16
+ show_assignments_plotly,
17
  )
18
+ from utils import centering, convert_to_word_embeddings, encode_sentence
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
  torch.manual_seed(42)
22
  np.random.seed(42)
23
  random.seed(42)
24
  import nltk
25
+
26
+ nltk.download("punkt")
27
 
28
 
29
  @st.cache_resource
30
  def init_model(model: str):
31
  tokenizer = AutoTokenizer.from_pretrained(model)
32
+ model = (
33
+ AutoModel.from_pretrained(model, output_hidden_states=True).to(device).eval()
34
+ )
35
  return tokenizer, model
36
 
37
 
38
  @st.cache_resource(max_entries=100)
39
+ def init_aligner(
40
+ ot_type: str, sinkhorn: bool, distortion: float, threshhold: float, tau: float
41
+ ):
42
  return Aligner(
43
  ot_type=ot_type,
44
  sinkhorn=sinkhorn,
 
45
  dist_type="cos",
46
  weight_type="uniform",
47
  distortion=distortion,
48
+ thresh=threshhold, # 0.25252525252525254
49
+ tau=tau, # 0.9803921568627451
50
+ div_type="--",
51
  )
52
 
53
 
 
56
 
57
  # Sidebar
58
  st.sidebar.markdown("## Settings & Parameters")
59
+ model = st.sidebar.selectbox(
60
+ "model", ["microsoft/deberta-v3-base", "bert-base-uncased"]
61
+ )
62
  layer = st.sidebar.slider(
63
+ "layer number for embeddings",
64
+ 0,
65
+ 11,
66
+ value=9,
67
  )
68
+ is_centering = st.sidebar.checkbox("centering embeddings", value=True)
69
  ot_type = st.sidebar.selectbox(
70
+ "ot_type", ["POT", "UOT", "OT"], help="optimal transport algorithm to be used"
 
71
  )
72
  ot_type = ot_type.lower()
73
  sinkhorn = st.sidebar.checkbox(
74
+ "sinkhorn", value=True, help="use sinkhorn algorithm"
 
75
  )
76
  distortion = st.sidebar.slider(
77
+ "distortion: $\kappa$",
78
+ 0.0,
79
+ 1.0,
80
+ value=0.20,
81
+ help="suppression of off-diagonal alignments",
82
  )
83
  tau = st.sidebar.slider(
84
+ "m / $\\tau$",
85
+ 0.0,
86
+ 1.0,
87
+ value=0.98,
88
+ help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties",
89
+ ) # with 0.02 interva
90
  threshhold = st.sidebar.slider(
91
+ "threshhold: $\lambda$",
92
+ 0.0,
93
+ 1.0,
94
+ value=0.22,
95
+ help="sparsity of alignment matrix",
96
+ ) # with 0.01 interval
97
+ show_assignments = st.sidebar.checkbox("show assignments", value=True)
98
+ if show_assignments:
99
+ n_neighbors = st.sidebar.slider(
100
+ "n_neighbors", 2, 10, value=8, help="number of neighbors for umap"
101
+ )
102
 
103
  # Content
104
+ st.markdown(
105
+ "## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment"
106
+ )
107
 
108
  col1, col2 = st.columns(2)
109
 
110
  with col1:
111
+ sent1 = st.text_area(
112
+ "sentence 1",
113
+ "By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.",
114
+ help="Initial text",
115
+ )
116
  with col2:
117
+ sent2 = st.text_area(
118
+ "sentence 2",
119
+ "Today there are only around 20,000 wild lions left in the world.",
120
+ help="Text to compare",
121
+ )
122
+
123
  tokenizer, model = init_model(model)
124
  aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)
125
 
 
143
  st.write(f"**word similarity matrix**")
144
  fig2 = plot_similarity_matrix_heatmap_plotly(similarity_matrix.T, sent1, sent2, cost_matrix.T)
145
  st.plotly_chart(fig2, use_container_width=True)
146
+
147
+ if show_assignments:
148
+ st.write(f"**Alignments after UMAP**")
149
+ word_embeddings = torch.vstack([s1_vec, s2_vec])
150
+ umap_embeddings = umap.UMAP(
151
+ n_neighbors=n_neighbors,
152
+ n_components=2,
153
+ random_state=42,
154
+ metric="cosine",
155
+ ).fit_transform(word_embeddings.detach().numpy())
156
+ print(umap_embeddings.shape)
157
+ fig3 = show_assignments_plotly(
158
+ align_matrix, umap_embeddings, sent1, sent2, thr=threshhold
159
+ )
160
+ st.plotly_chart(fig3, use_container_width=True)
161
 
162
  st.divider()
163
  st.subheader('Refs')
164
  st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
165
 
166
  if __name__ == '__main__':
167
+ main()
otfuncs.py CHANGED
@@ -1,17 +1,22 @@
1
- import numpy as np
2
  import torch
3
  import torch.nn.functional as F
4
  from ot.backend import get_backend
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
8
- def compute_distance_matrix_cosine(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
9
- sim_matrix = (torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t()) + 1.0) / 2 # Range 0-1
10
- C = apply_distortion(sim_matrix, distortion_ratio)
 
 
 
 
 
 
11
  C = min_max_scaling(C) # Range 0-1
12
  C = 1.0 - C # Convert to distance
13
 
14
- return C, sim_matrix
15
 
16
 
17
  def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
@@ -30,15 +35,20 @@ def apply_distortion(sim_matrix, ratio):
30
  if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
31
  return sim_matrix
32
 
33
- pos_x = torch.tensor([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
34
- device=device)
35
- pos_y = torch.tensor([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
36
- device=device)
37
- distortion_mask = 1.0 - ((pos_x - pos_y.T) ** 2) * ratio
 
 
 
 
 
38
 
39
  sim_matrix = torch.mul(sim_matrix, distortion_mask)
40
 
41
- return sim_matrix
42
 
43
 
44
  def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
@@ -48,8 +58,12 @@ def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
48
 
49
 
50
  def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
51
- s1_weights = torch.ones(s1_word_embeddigs.shape[0], dtype=torch.float64, device=device)
52
- s2_weights = torch.ones(s2_word_embeddigs.shape[0], dtype=torch.float64, device=device)
 
 
 
 
53
 
54
  # # Uniform weights to make L2 norm=1
55
  # s1_weights /= torch.linalg.norm(s1_weights)
@@ -65,4 +79,4 @@ def min_max_scaling(C):
65
  C_min = nx.min(C)
66
  C_max = nx.max(C)
67
  C = (C - C_min + eps) / (C_max - C_min + eps)
68
- return C
 
 
1
  import torch
2
  import torch.nn.functional as F
3
  from ot.backend import get_backend
4
 
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
6
 
7
+
8
+ def compute_distance_matrix_cosine(
9
+ s1_word_embeddigs, s2_word_embeddigs, distortion_ratio
10
+ ):
11
+ sim_matrix = (
12
+ torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t())
13
+ + 1.0
14
+ ) / 2 # Range 0-1
15
+ C, relative_distance = apply_distortion(sim_matrix, distortion_ratio)
16
  C = min_max_scaling(C) # Range 0-1
17
  C = 1.0 - C # Convert to distance
18
 
19
+ return C, sim_matrix, relative_distance
20
 
21
 
22
  def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
 
35
  if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
36
  return sim_matrix
37
 
38
+ pos_x = torch.tensor(
39
+ [[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
40
+ device=device,
41
+ )
42
+ pos_y = torch.tensor(
43
+ [[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
44
+ device=device,
45
+ )
46
+ relative_distance = (pos_x - pos_y.T) ** 2
47
+ distortion_mask = 1.0 - relative_distance * ratio
48
 
49
  sim_matrix = torch.mul(sim_matrix, distortion_mask)
50
 
51
+ return sim_matrix, relative_distance
52
 
53
 
54
  def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
 
58
 
59
 
60
  def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
61
+ s1_weights = torch.ones(
62
+ s1_word_embeddigs.shape[0], dtype=torch.float64, device=device
63
+ )
64
+ s2_weights = torch.ones(
65
+ s2_word_embeddigs.shape[0], dtype=torch.float64, device=device
66
+ )
67
 
68
  # # Uniform weights to make L2 norm=1
69
  # s1_weights /= torch.linalg.norm(s1_weights)
 
79
  C_min = nx.min(C)
80
  C_max = nx.max(C)
81
  C = (C - C_min + eps) / (C_max - C_min + eps)
82
+ return C
plotools.py CHANGED
@@ -8,74 +8,79 @@ def _debug_non_unique_axis_values(sent1: list[str], sent2: list[str]):
8
  using zero-width-space
9
  cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
10
  """
11
- sent1 = [word + i*'\u200b' for i, word in enumerate(sent1)]
12
- sent2 = [word + i*'\u200b' for i, word in enumerate(sent2)]
13
-
14
  return sent1, sent2
15
 
16
 
17
  def discrete_colorscale(bvals, colors):
18
  """
19
  bvals - list of values bounding intervals/ranges of interest
20
- colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0 <= k < len(bvals)-1
21
  returns the plotly discrete colorscale
22
  ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
23
  """
24
- if len(bvals) != len(colors)+1:
25
- raise ValueError('len(boundary values) should be equal to len(colors)+1')
26
- bvals = sorted(bvals)
27
- nvals = [(v-bvals[0])/(bvals[-1]-bvals[0]) for v in bvals] #normalized values
28
-
29
- dcolorscale = [] #discrete colorscale
 
 
30
  for k in range(len(colors)):
31
- dcolorscale.extend([[nvals[k], colors[k]], [nvals[k+1], colors[k]]])
32
- return dcolorscale
33
 
34
 
35
  def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
36
  align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
37
  sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
38
- _colors = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
 
 
 
 
 
 
 
 
 
39
  _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
40
 
41
  colorscale = discrete_colorscale(_ticks, _colors)
42
 
43
  fig = go.Figure()
44
-
45
- fig.add_trace(go.Heatmap(
46
- z=align_matrix,
47
- customdata=Cost,
48
- x=sent1,
49
- y=sent2,
50
- xgap=2,
51
- ygap=2,
52
- colorscale=colorscale,
53
- colorbar=dict(
54
- tick0=0,
55
- dtick=0.125,
56
- outlinewidth=0
57
- ),
58
- hovertemplate=
59
- 'x: %{x}<br>' +
60
- 'y: %{y}<br>' +
61
- 'P: %{z:.3f}<br>' +
62
- 'cost: %{customdata:.3f} ',
63
- name=''
64
- ))
65
  fig.update_layout(
66
- #xaxis=dict(scaleanchor='y'),
67
- yaxis=dict(autorange='reversed'),
68
- margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
69
- plot_bgcolor='rgba(0,0,0,0)',
70
  font=dict(
71
  size=16,
72
  ),
73
  hoverlabel=dict(
74
- bgcolor="#555",
75
- font_color="white",
76
- font_size=14,
77
- font_family="Open Sans"
78
- )
79
  )
80
  fig.update_xaxes(
81
  tickangle=-45,
@@ -83,47 +88,104 @@ def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cos
83
  return fig
84
 
85
 
86
- def plot_similarity_matrix_heatmap_plotly(similarity_matrix, sent1, sent2, Cost):
 
 
87
  sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
88
 
89
  fig = go.Figure()
90
-
91
- fig.add_trace(go.Heatmap(
92
- z=similarity_matrix,
93
- customdata=Cost,
94
- x=sent1,
95
- y=sent2,
96
- xgap=2,
97
- ygap=2,
98
- colorscale="Reds",
99
- colorbar=dict(
100
- tick0=0,
101
- dtick=0.125,
102
- outlinewidth=0
103
- ),
104
- hovertemplate=
105
- 'x: %{x}<br>' +
106
- 'y: %{y}<br>' +
107
- 'cosine: %{z:.3f}<br>' +
108
- 'cost: %{customdata:.3f} ',
109
- name=''
110
- ))
111
  fig.update_layout(
112
- #xaxis=dict(scaleanchor='y'),
113
- yaxis=dict(autorange='reversed'),
114
- margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
115
- plot_bgcolor='rgba(0,0,0,0)',
116
  font=dict(
117
  size=16,
118
  ),
119
  hoverlabel=dict(
120
- bgcolor="#555",
121
- font_color="white",
122
- font_size=14,
123
- font_family="Open Sans"
124
- )
125
  )
126
  fig.update_xaxes(
127
  tickangle=-45,
128
  )
129
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  using zero-width-space
9
  cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
10
  """
11
+ sent1 = [word + i * "\u200b" for i, word in enumerate(sent1)]
12
+ sent2 = [word + i * "\u200b" for i, word in enumerate(sent2)]
13
+
14
  return sent1, sent2
15
 
16
 
17
  def discrete_colorscale(bvals, colors):
18
  """
19
  bvals - list of values bounding intervals/ranges of interest
20
+ colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0<=k < len(bvals)-1
21
  returns the plotly discrete colorscale
22
  ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
23
  """
24
+ if len(bvals) != len(colors) + 1:
25
+ raise ValueError("len(boundary values) should be equal to len(colors)+1")
26
+ bvals = sorted(bvals)
27
+ nvals = [
28
+ (v - bvals[0]) / (bvals[-1] - bvals[0]) for v in bvals
29
+ ] # normalized values
30
+
31
+ dcolorscale = [] # discrete colorscale
32
  for k in range(len(colors)):
33
+ dcolorscale.extend([[nvals[k], colors[k]], [nvals[k + 1], colors[k]]])
34
+ return dcolorscale
35
 
36
 
37
  def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
38
  align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
39
  sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
40
+ _colors = [
41
+ "#F2F2F2",
42
+ "#E0F4FA",
43
+ "#BEE4F0",
44
+ "#88CCE5",
45
+ "#33b7df",
46
+ "#1B88A6",
47
+ "#105264",
48
+ "#092E39",
49
+ ]
50
  _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
51
 
52
  colorscale = discrete_colorscale(_ticks, _colors)
53
 
54
  fig = go.Figure()
55
+
56
+ fig.add_trace(
57
+ go.Heatmap(
58
+ z=align_matrix,
59
+ customdata=Cost,
60
+ x=sent1,
61
+ y=sent2,
62
+ xgap=2,
63
+ ygap=2,
64
+ colorscale=colorscale,
65
+ colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
66
+ hovertemplate="x: %{x}<br>"
67
+ + "y: %{y}<br>"
68
+ + "P: %{z:.3f}<br>"
69
+ + "cost: %{customdata:.3f} ",
70
+ name="",
71
+ )
72
+ )
 
 
 
73
  fig.update_layout(
74
+ # xaxis=dict(scaleanchor='y'),
75
+ yaxis=dict(autorange="reversed"),
76
+ margin={"l": 0, "r": 0, "t": 0, "b": 0},
77
+ plot_bgcolor="rgba(0,0,0,0)",
78
  font=dict(
79
  size=16,
80
  ),
81
  hoverlabel=dict(
82
+ bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
83
+ ),
 
 
 
84
  )
85
  fig.update_xaxes(
86
  tickangle=-45,
 
88
  return fig
89
 
90
 
91
+ def plot_similarity_matrix_heatmap_plotly(
92
+ similarity_matrix, sent1, sent2, Cost, colorscale="Reds", hover_z="cosine"
93
+ ):
94
  sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
95
 
96
  fig = go.Figure()
97
+
98
+ fig.add_trace(
99
+ go.Heatmap(
100
+ z=similarity_matrix,
101
+ customdata=Cost,
102
+ x=sent1,
103
+ y=sent2,
104
+ xgap=2,
105
+ ygap=2,
106
+ colorscale=colorscale,
107
+ colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
108
+ hovertemplate="x: %{x}<br>"
109
+ + "y: %{y}<br>"
110
+ + f"{hover_z}: "
111
+ + "%{z:.3f}<br>"
112
+ + "cost: %{customdata:.3f} ",
113
+ name="",
114
+ )
115
+ )
 
 
116
  fig.update_layout(
117
+ # xaxis=dict(scaleanchor='y'),
118
+ yaxis=dict(autorange="reversed"),
119
+ margin={"l": 0, "r": 0, "t": 0, "b": 0},
120
+ plot_bgcolor="rgba(0,0,0,0)",
121
  font=dict(
122
  size=16,
123
  ),
124
  hoverlabel=dict(
125
+ bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
126
+ ),
 
 
 
127
  )
128
  fig.update_xaxes(
129
  tickangle=-45,
130
  )
131
+ return fig
132
+
133
+
134
+ def show_assignments_plotly(P, word_embeddings, sents1, sents2, thr=0):
135
+ P = np.where(P <= thr, 0, P)
136
+
137
+ s1_end = len(sents1)
138
+ a = word_embeddings[:s1_end]
139
+ b = word_embeddings[s1_end:]
140
+
141
+ traces = []
142
+ sample = 0
143
+
144
+ for i in range(a.shape[0]):
145
+ for j in range(b.shape[0]):
146
+ if P[i, j] > 0:
147
+ sample += 1
148
+ traces.append(
149
+ go.Scatter(
150
+ x=[a[i, 0], b[j, 0]],
151
+ y=[a[i, 1], b[j, 1]],
152
+ mode="lines",
153
+ line=dict(color="black", width=P[i, j] * 2),
154
+ opacity=P[i, j],
155
+ name=f"{sample}",
156
+ )
157
+ )
158
+
159
+ # ソースサンプルの描画
160
+ traces.append(
161
+ go.Scatter(
162
+ x=a[:, 0],
163
+ y=a[:, 1],
164
+ mode="markers+text",
165
+ marker=dict(color="blue", size=8, symbol="cross"),
166
+ text=sents1,
167
+ textposition="top center",
168
+ name="Source samples",
169
+ )
170
+ )
171
+
172
+ # ターゲットサンプルの描画
173
+ traces.append(
174
+ go.Scatter(
175
+ x=b[:, 0],
176
+ y=b[:, 1],
177
+ mode="markers+text",
178
+ marker=dict(color="red", size=8, symbol="x"),
179
+ text=sents2,
180
+ textposition="bottom center",
181
+ name="Target samples",
182
+ )
183
+ )
184
+
185
+ layout = go.Layout(
186
+ showlegend=True,
187
+ margin=dict(l=0, r=0, t=10, b=0),
188
+ )
189
+
190
+ fig = go.Figure(data=traces, layout=layout)
191
+ return fig
requirements.txt CHANGED
@@ -6,4 +6,5 @@ transformers==4.30.2
6
  matplotlib==3.7.1
7
  plotly==5.15.0
8
  torch==2.0.1
9
- nltk==3.8.1
 
 
6
  matplotlib==3.7.1
7
  plotly==5.15.0
8
  torch==2.0.1
9
+ nltk==3.8.1
10
+ umap-learn==0.5.5