4kasha committed
Commit
f31ab4f
1 Parent(s): 37d364a
Files changed (3)
  1. aligner.py +4 -4
  2. app.py +4 -6
  3. otfuncs.py +3 -3
aligner.py CHANGED
@@ -30,19 +30,19 @@ class Aligner:
         self.weight_func = compute_weights_norm
 
     def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
-        P, Cost, log, similarity_matrix, relative_distance = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
+        P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
         print(log.keys())
         if torch.is_tensor(P):
             P = P.to('cpu').numpy()
         loss = log.get('cost', 'NotImplemented')
 
-        return P, Cost, loss, similarity_matrix, relative_distance
+        return P, Cost, loss, similarity_matrix
 
     def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
         s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
         s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
 
-        C, similarity_matrix, relative_distance = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
+        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
         s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
 
         if self.ot_type == 'ot':
@@ -96,7 +96,7 @@ class Aligner:
         elif self.ot_type == 'none':
             P = 1 - C
 
-        return P, C, log, similarity_matrix, relative_distance
+        return P, C, log, similarity_matrix
 
     def convert_to_numpy(self, s1_weights, s2_weights, C):
         if torch.is_tensor(s1_weights):
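For reference, a minimal sketch of how the post-commit compute_alignment_matrixes signature would be consumed. The Aligner constructor keywords below are assumptions mirrored from the init_aligner call in app.py (the constructor itself is not shown in this diff), and the embedding shapes are illustrative:

import torch

from aligner import Aligner

# Hypothetical inputs: two token-embedding matrices (5 and 7 tokens, 768-dim).
s1_word_embeddigs = torch.randn(5, 768)
s2_word_embeddigs = torch.randn(7, 768)

# Keyword names assumed from init_aligner in app.py; adjust to the real signature.
aligner = Aligner(
    ot_type="ot",
    dist_type="cos",
    weight_type="uniform",
    distortion=0.0,
    thresh=0.22,
    tau=0.98,
    div_type="--",
)

# After this commit the method returns four values; relative_distance is no longer part of the tuple.
P, Cost, loss, similarity_matrix = aligner.compute_alignment_matrixes(
    s1_word_embeddigs, s2_word_embeddigs
)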
app.py CHANGED
@@ -8,8 +8,6 @@ from nltk.tokenize import word_tokenize
 from transformers import AutoModel, AutoTokenizer
 
 from aligner import Aligner
-
-# from utils import align_matrix_heatmap, plot_align_matrix_heatmap
 from plotools import (
     plot_align_matrix_heatmap_plotly,
     plot_similarity_matrix_heatmap_plotly,
@@ -45,8 +43,8 @@ def init_aligner(
         dist_type="cos",
         weight_type="uniform",
         distortion=distortion,
-        thresh=threshhold, # 0.25252525252525254
-        tau=tau, # 0.9803921568627451
+        thresh=threshhold,
+        tau=tau,
         div_type="--",
     )
 
@@ -86,14 +84,14 @@ def main():
         1.0,
         value=0.98,
         help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties",
-    ) # with 0.02 interva
+    )
     threshhold = st.sidebar.slider(
         "threshhold: $\lambda$",
         0.0,
         1.0,
         value=0.22,
         help="sparsity of alignment matrix",
-    ) # with 0.01 interval
+    )
     show_assignments = st.sidebar.checkbox("show assignments", value=True)
     if show_assignments:
         n_neighbors = st.sidebar.slider(
otfuncs.py CHANGED
@@ -12,11 +12,11 @@ def compute_distance_matrix_cosine(
         torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t())
         + 1.0
     ) / 2 # Range 0-1
-    C, relative_distance = apply_distortion(sim_matrix, distortion_ratio)
+    C = apply_distortion(sim_matrix, distortion_ratio)
     C = min_max_scaling(C) # Range 0-1
     C = 1.0 - C # Convert to distance
 
-    return C, sim_matrix, relative_distance
+    return C, sim_matrix
 
 
 def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
@@ -48,7 +48,7 @@ def apply_distortion(sim_matrix, ratio):
 
     sim_matrix = torch.mul(sim_matrix, distortion_mask)
 
-    return sim_matrix, relative_distance
+    return sim_matrix
 
 
 def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
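Similarly, a minimal sketch of the trimmed otfuncs.py signatures after this commit. The argument order follows the calls visible in the diff (the dist_func call in aligner.py passes three positional arguments); the shapes and the distortion_ratio value of 0.0 are illustrative:

import torch

from otfuncs import compute_distance_matrix_cosine, compute_weights_norm

# Hypothetical embeddings, cast to float64 as compute_optimal_transport does.
s1_word_embeddigs = torch.randn(5, 768, dtype=torch.float64)
s2_word_embeddigs = torch.randn(7, 768, dtype=torch.float64)

# Post-commit: the cosine helper returns only (C, sim_matrix); relative_distance is gone.
C, sim_matrix = compute_distance_matrix_cosine(s1_word_embeddigs, s2_word_embeddigs, 0.0)

# Weight function signature is unchanged by this commit.
s1_weights, s2_weights = compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs)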