import numpy as np
import torch
import ot
from otfuncs import (
compute_distance_matrix_cosine,
compute_distance_matrix_l2,
compute_weights_norm,
compute_weights_uniform,
min_max_scaling
)
class Aligner:
    """Computes word-alignment matrices between two sentences via optimal transport.

    Supported transport types (``ot_type``):
      - 'ot':     balanced OT (exact EMD or entropic Sinkhorn)
      - 'pot':    partial OT (exact or entropic)
      - 'uot':    unbalanced OT (stabilized Sinkhorn)
      - 'uot-mm': unbalanced OT (majorization-minimization)
      - 'none':   no transport; alignment is simply 1 - distance
    """

    def __init__(self, ot_type, sinkhorn, dist_type, weight_type, distortion, thresh, tau, **kwargs):
        """
        Args:
            ot_type: One of 'ot', 'pot', 'uot', 'uot-mm', 'none' (see class docstring).
            sinkhorn: If True, use the entropic solver for 'ot'/'pot'.
            dist_type: 'cos' selects cosine distance; anything else selects L2.
            weight_type: 'uniform' selects uniform marginal weights; anything
                else selects norm-based weights.
            distortion: Distortion parameter forwarded to the distance function.
            thresh: Alignment threshold (stored for callers; unused here).
            tau: Transported-mass fraction ('pot') or marginal-relaxation
                strength ('uot'/'uot-mm').
            **kwargs: Optional 'div_type' divergence name, used only by the
                'uot-mm' solver.
        """
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.dist_type = dist_type
        self.weight_type = weight_type
        self.distortion = distortion  # fixed attribute typo (was 'distotion')
        self.thresh = thresh
        self.tau = tau
        # Hyper-parameters shared by all iterative solvers.
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000
        # Only needed by 'uot-mm'; don't hard-fail when it isn't supplied.
        self.div_type = kwargs.get('div_type')
        self.dist_func = compute_distance_matrix_cosine if dist_type == 'cos' else compute_distance_matrix_l2
        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
        """Return (P, Cost, loss, similarity_matrix) with P as a numpy array.

        ``loss`` is the solver's reported cost when present in its log,
        otherwise the string 'NotImplemented'.
        """
        P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
        if torch.is_tensor(P):
            P = P.to('cpu').numpy()
        loss = log.get('cost', 'NotImplemented')
        return P, Cost, loss, similarity_matrix

    def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
        """Solve the configured OT problem between two embedding sets.

        Returns:
            Tuple (P, C, log, similarity_matrix): the (min-max scaled)
            transport plan, the distance matrix, the solver log dict
            (empty for ot_type == 'none'), and the similarity matrix.

        Raises:
            ValueError: If ``self.ot_type`` is not a recognized type.
        """
        s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
        s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distortion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

        if self.ot_type == 'ot':
            # Balanced OT requires both marginals to carry equal total mass.
            s1_weights = s1_weights / s1_weights.sum()
            s2_weights = s2_weights / s2_weights.sum()
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            if self.sinkhorn:
                P, log = ot.bregman.sinkhorn_log(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon, stopThr=self.stopThr,
                    numItermax=self.numItermax, log=True
                )
            else:
                P, log = ot.emd(s1_weights, s2_weights, C, log=True)
            # Min-max normalization
            P = min_max_scaling(P)
        elif self.ot_type == 'pot':
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            # Transportable mass; capping at the smaller marginal total copes
            # with floating-point round error in the weight sums.
            m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * self.tau
            if self.sinkhorn:
                P, log = ot.partial.entropic_partial_wasserstein(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon,
                    m=m, stopThr=self.stopThr, numItermax=self.numItermax, log=True
                )
            else:
                P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)
            # Min-max normalization
            P = min_max_scaling(P)
        elif self.ot_type == 'uot':
            # NOTE(review): weights/C stay as torch tensors on the 'uot' paths;
            # this appears to rely on POT's backend dispatch — confirm torch support.
            P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
                s1_weights, s2_weights, C, reg=self.epsilon, reg_m=self.tau,
                stopThr=self.stopThr, numItermax=self.numItermax, log=True
            )
            # Min-max normalization
            P = min_max_scaling(P)
        elif self.ot_type == 'uot-mm':
            P, log = ot.unbalanced.mm_unbalanced(
                s1_weights, s2_weights, C, reg_m=self.tau, div=self.div_type,
                stopThr=self.stopThr, numItermax=self.numItermax, log=True
            )
            # Min-max normalization
            P = min_max_scaling(P)
        elif self.ot_type == 'none':
            # No transport: alignment strength is just similarity (1 - distance).
            P = 1 - C
            log = {}  # bug fix: 'log' was previously unbound on this path
        else:
            # Previously an unrecognized ot_type fell through and raised a
            # confusing NameError on return; fail fast and explicitly instead.
            raise ValueError('Unknown ot_type: {}'.format(self.ot_type))

        return P, C, log, similarity_matrix

    def convert_to_numpy(self, s1_weights, s2_weights, C):
        """Move any torch tensors among the inputs to CPU numpy arrays."""
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()
        return s1_weights, s2_weights, C