import numpy as np
import torch
import ot

from otfuncs import (
    compute_distance_matrix_cosine,
    compute_distance_matrix_l2,
    compute_weights_norm,
    compute_weights_uniform,
    min_max_scaling,
)


class Aligner:
    """Compute a soft alignment matrix between two sets of word embeddings.

    The alignment is obtained either from an optimal-transport solver of the
    POT (``ot``) library or, when ``ot_type == 'none'``, directly from the
    cost matrix as ``1 - C``.
    """

    def __init__(self, ot_type, sinkhorn, dist_type, weight_type, distortion, thresh, tau, **kwargs):
        """Store the configuration and select the distance / weight helpers.

        Args:
            ot_type: Solver family. ``'ot'``, ``'pot'``, ``'uot'``,
                ``'uot-mm'`` and ``'none'`` are handled.
            sinkhorn: If truthy, the entropic (Sinkhorn) solver is used for
                the ``'ot'`` and ``'pot'`` families.
            dist_type: ``'cos'`` selects ``compute_distance_matrix_cosine``;
                any other value selects ``compute_distance_matrix_l2``.
            weight_type: ``'uniform'`` selects ``compute_weights_uniform``;
                any other value selects ``compute_weights_norm``.
            distortion: Forwarded as the third argument of the distance
                function.
            thresh: Stored on the instance; not used by the methods of this
                class.
            tau: Mass fraction for ``'pot'`` and marginal-relaxation weight
                (``reg_m``) for the ``'uot'`` solvers.
            **kwargs: ``div_type`` is read (divergence passed to
                ``ot.unbalanced.mm_unbalanced``); defaults to ``'kl'`` when
                omitted.
        """
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.dist_type = dist_type
        self.weight_type = weight_type
        # NOTE: the attribute name is misspelled; it is kept as-is because
        # code outside this class may already read ``distotion``.
        self.distotion = distortion
        self.thresh = thresh
        self.tau = tau

        # Fixed solver hyper-parameters shared by the iterative solvers.
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000

        # Only the 'uot-mm' solver consumes div_type, so do not require it
        # for every configuration.
        self.div_type = kwargs.get('div_type', 'kl')

        if dist_type == 'cos':
            self.dist_func = compute_distance_matrix_cosine
        else:
            self.dist_func = compute_distance_matrix_l2

        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
        """Return the alignment matrix together with cost, loss and similarity.

        Args:
            s1_word_embeddigs: Embeddings of the first sequence (torch tensor).
            s2_word_embeddigs: Embeddings of the second sequence (torch tensor).

        Returns:
            A tuple ``(P, Cost, loss, similarity_matrix)``. ``P`` is converted
            to a NumPy array if the solver returned a torch tensor. ``loss``
            is the ``'cost'`` entry of the solver log, or the string
            ``'NotImplemented'`` when the log has no such entry.
        """
        P, Cost, log, similarity_matrix = self.compute_optimal_transport(
            s1_word_embeddigs, s2_word_embeddigs
        )
        if torch.is_tensor(P):
            P = P.to('cpu').numpy()
        loss = log.get('cost', 'NotImplemented')
        return P, Cost, loss, similarity_matrix

    def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
        """Build the cost matrix and solve the configured transport problem.

        Args:
            s1_word_embeddigs: Embeddings of the first sequence (torch tensor).
            s2_word_embeddigs: Embeddings of the second sequence (torch tensor).

        Returns:
            A tuple ``(P, C, log, similarity_matrix)`` where ``P`` is the
            alignment matrix, ``C`` the cost matrix, ``log`` the solver log
            dictionary (empty for ``ot_type == 'none'``) and
            ``similarity_matrix`` the second value returned by the distance
            function.

        Raises:
            ValueError: If ``self.ot_type`` is not one of the handled values.
        """
        # Solvers are run in double precision.
        s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
        s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

        if self.ot_type == 'ot':
            # Balanced OT: both marginals are normalised to sum to one.
            s1_weights = s1_weights / s1_weights.sum()
            s2_weights = s2_weights / s2_weights.sum()
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)

            if self.sinkhorn:
                P, log = ot.bregman.sinkhorn_log(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon,
                    stopThr=self.stopThr,
                    numItermax=self.numItermax,
                    log=True,
                )
            else:
                P, log = ot.emd(s1_weights, s2_weights, C, log=True)
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'pot':
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            # Transport only a fraction tau of the smaller total mass.
            m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * self.tau

            if self.sinkhorn:
                P, log = ot.partial.entropic_partial_wasserstein(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon,
                    m=m,
                    stopThr=self.stopThr,
                    numItermax=self.numItermax,
                    log=True,
                )
            else:
                # NOTE(review): the original carried a "To cope with round
                # error" remark here, but no rounding adjustment is applied
                # to m or the weights -- confirm whether one is missing.
                P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)
            # Min-max normalization
            P = min_max_scaling(P)

        elif 'uot' in self.ot_type:
            # Weights and C are passed to the solver without NumPy conversion.
            tau = self.tau
            if self.ot_type == 'uot':
                P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon,
                    reg_m=tau,
                    stopThr=self.stopThr,
                    numItermax=self.numItermax,
                    log=True,
                )
            elif self.ot_type == 'uot-mm':
                P, log = ot.unbalanced.mm_unbalanced(
                    s1_weights, s2_weights, C,
                    reg_m=tau,
                    div=self.div_type,
                    stopThr=self.stopThr,
                    numItermax=self.numItermax,
                    log=True,
                )
            else:
                # Previously fell through with P and log unbound.
                raise ValueError(f"Unsupported ot_type: {self.ot_type!r}")
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'none':
            # No transport problem is solved; the alignment is derived from
            # the cost matrix and there is no solver log to report.
            P = 1 - C
            log = {}

        else:
            raise ValueError(f"Unsupported ot_type: {self.ot_type!r}")

        return P, C, log, similarity_matrix

    def convert_to_numpy(self, s1_weights, s2_weights, C):
        """Move any torch tensors among the arguments to CPU NumPy arrays.

        Args:
            s1_weights: Weights of the first sequence (tensor or array).
            s2_weights: Weights of the second sequence (tensor or array).
            C: Cost matrix (tensor or array).

        Returns:
            ``(s1_weights, s2_weights, C)`` with every torch tensor replaced
            by its NumPy counterpart; non-tensor inputs are returned unchanged.
        """
        # Each argument is checked on its own so that a mix of tensors and
        # arrays is handled.
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
        if torch.is_tensor(s2_weights):
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()
        return s1_weights, s2_weights, C