|
import numpy as np |
|
import torch |
|
import ot |
|
from utils import ( |
|
compute_distance_matrix_cosine, |
|
compute_distance_matrix_l2, |
|
compute_weights_norm, |
|
compute_weights_uniform, |
|
min_max_scaling |
|
) |
|
|
|
class Aligner: |
|
def __init__(self, ot_type, sinkhorn, chimera, dist_type, weight_type, distortion, thresh, tau, **kwargs): |
|
self.ot_type = ot_type |
|
self.sinkhorn = sinkhorn |
|
self.chimera = chimera |
|
self.dist_type = dist_type |
|
self.weight_type = weight_type |
|
self.distotion = distortion |
|
self.thresh = thresh |
|
self.tau = tau |
|
self.epsilon = 0.1 |
|
self.stopThr = 1e-6 |
|
self.numItermax = 1000 |
|
self.div_type = kwargs['div_type'] |
|
|
|
self.dist_func = compute_distance_matrix_cosine if dist_type == 'cos' else compute_distance_matrix_l2 |
|
if weight_type == 'uniform': |
|
self.weight_func = compute_weights_uniform |
|
else: |
|
self.weight_func = compute_weights_norm |
|
|
|
def compute_alignment_matrixes(self, s1_vecs, s2_vecs): |
|
self.align_matrixes = [] |
|
for vecX, vecY in zip(s1_vecs, s2_vecs): |
|
P = self.compute_optimal_transport(vecX, vecY) |
|
if torch.is_tensor(P): |
|
P = P.to('cpu').numpy() |
|
|
|
self.align_matrixes.append(P) |
|
|
|
def get_alignments(self, thresh, assign_cost=False): |
|
assert len(self.align_matrixes) > 0 |
|
|
|
self.thresh = thresh |
|
all_alignments = [] |
|
for P in self.align_matrixes: |
|
alignments = self.matrix_to_alignments(P, assign_cost) |
|
all_alignments.append(alignments) |
|
|
|
return all_alignments |
|
|
|
def matrix_to_alignments(self, P, assign_cost): |
|
alignments = set() |
|
align_pairs = np.transpose(np.nonzero(P > self.thresh)) |
|
if assign_cost: |
|
for i_j in align_pairs: |
|
alignments.add('{0}-{1}-{2:.4f}'.format(i_j[0], i_j[1], P[i_j[0], i_j[1]])) |
|
else: |
|
for i_j in align_pairs: |
|
alignments.add('{0}-{1}'.format(i_j[0], i_j[1])) |
|
|
|
return alignments |
|
|
|
def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs): |
|
s1_word_embeddigs = s1_word_embeddigs.to(torch.float64) |
|
s2_word_embeddigs = s2_word_embeddigs.to(torch.float64) |
|
|
|
C = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion) |
|
s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs) |
|
|
|
if self.ot_type == 'ot': |
|
s1_weights = s1_weights / s1_weights.sum() |
|
s2_weights = s2_weights / s2_weights.sum() |
|
s1_weights, s2_weights, C = self.comvert_to_numpy(s1_weights, s2_weights, C) |
|
|
|
if self.sinkhorn: |
|
P = ot.bregman.sinkhorn_log(s1_weights, s2_weights, C, reg=self.epsilon, stopThr=self.stopThr, |
|
numItermax=self.numItermax) |
|
else: |
|
P = ot.emd(s1_weights, s2_weights, C) |
|
|
|
P = min_max_scaling(P) |
|
|
|
elif self.ot_type == 'pot': |
|
if self.chimera: |
|
m = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs) |
|
m = min(1.0, m.item()) |
|
else: |
|
m = self.tau |
|
|
|
s1_weights, s2_weights, C = self.comvert_to_numpy(s1_weights, s2_weights, C) |
|
m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m |
|
|
|
if self.sinkhorn: |
|
P = ot.partial.entropic_partial_wasserstein(s1_weights, s2_weights, C, |
|
reg=self.epsilon, |
|
m=m, stopThr=self.stopThr, numItermax=self.numItermax) |
|
else: |
|
|
|
P = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m) |
|
|
|
P = min_max_scaling(P) |
|
|
|
elif 'uot' in self.ot_type: |
|
if self.chimera: |
|
tau = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs) |
|
else: |
|
tau = self.tau |
|
|
|
if self.ot_type == 'uot': |
|
P = ot.unbalanced.sinkhorn_stabilized_unbalanced(s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau, |
|
stopThr=self.stopThr, numItermax=self.numItermax) |
|
elif self.ot_type == 'uot-mm': |
|
P = ot.unbalanced.mm_unbalanced(s1_weights, s2_weights, C, reg_m=tau, div=self.div_type, |
|
stopThr=self.stopThr, numItermax=self.numItermax) |
|
|
|
P = min_max_scaling(P) |
|
|
|
elif self.ot_type == 'none': |
|
P = 1 - C |
|
|
|
return P |
|
|
|
def comvert_to_numpy(self, s1_weights, s2_weights, C): |
|
if torch.is_tensor(s1_weights): |
|
s1_weights = s1_weights.to('cpu').numpy() |
|
s2_weights = s2_weights.to('cpu').numpy() |
|
if torch.is_tensor(C): |
|
C = C.to('cpu').numpy() |
|
|
|
return s1_weights, s2_weights, C |