UOT / aligner.py
4kasha
update demo
94f5fd3
raw
history blame
4.6 kB
import numpy as np
import torch
import ot
from otfuncs import (
compute_distance_matrix_cosine,
compute_distance_matrix_l2,
compute_weights_norm,
compute_weights_uniform,
min_max_scaling
)
class Aligner:
    """Build alignment matrices between two sets of word embeddings.

    A cost matrix and marginal weights are derived from the embeddings and
    handed to one of several solvers from the POT (``ot``) package, selected
    by ``ot_type``:

    * ``'ot'``     -- balanced OT (``ot.emd`` or ``ot.bregman.sinkhorn_log``)
    * ``'pot'``    -- partial OT (``ot.partial.*``)
    * ``'uot'``    -- unbalanced OT (``sinkhorn_stabilized_unbalanced``)
    * ``'uot-mm'`` -- unbalanced OT (``mm_unbalanced``)
    * ``'none'``   -- no solver; the alignment matrix is simply ``1 - C``
    """

    def __init__(self, ot_type, sinkhorn, chimera, dist_type, weight_type, distortion, thresh, tau, **kwargs):
        """Store the configuration and select the distance / weight functions.

        Args:
            ot_type: Solver selector, see the class docstring.
            sinkhorn: If truthy, use the entropic (Sinkhorn) solver for the
                ``'ot'`` and ``'pot'`` types instead of the exact one.
            chimera: If truthy, scale ``tau`` by ``self.bertscore_F1(...)``
                for the ``'pot'`` and ``'uot*'`` types.
            dist_type: ``'cos'`` selects the cosine distance function; any
                other value selects the L2 distance function.
            weight_type: ``'uniform'`` selects uniform weights; any other
                value selects norm-based weights.
            distortion: Passed as third argument to the distance function.
            thresh: Stored on the instance; not used by the methods here.
            tau: Mass fraction (``'pot'``) or ``reg_m`` value (``'uot*'``).
            **kwargs: ``div_type`` is forwarded as ``div`` to
                ``ot.unbalanced.mm_unbalanced``. It is optional and falls
                back to ``'kl'``.
        """
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.chimera = chimera
        self.dist_type = dist_type
        self.weight_type = weight_type
        # NOTE: the attribute name is misspelled ("distotion"); it is kept
        # unchanged so that any external code reading it keeps working.
        self.distotion = distortion
        self.thresh = thresh
        self.tau = tau
        # Solver settings: epsilon is passed as ``reg`` to the entropic solvers.
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000
        # Only 'uot-mm' consumes div_type, so do not require it for the others.
        self.div_type = kwargs.get('div_type', 'kl')

        self.dist_func = compute_distance_matrix_cosine if dist_type == 'cos' else compute_distance_matrix_l2
        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
        """Return ``(P, Cost, loss, similarity_matrix)`` for two sentences.

        ``P`` is moved to the CPU and converted to a NumPy array when the
        solver returned a torch tensor. ``loss`` is the ``'cost'`` entry of
        the solver log, or the string ``'NotImplemented'`` when the log has
        no such entry (always the case for ``ot_type == 'none'``).
        """
        P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
        if torch.is_tensor(P):
            P = P.to('cpu').numpy()
        loss = log.get('cost', 'NotImplemented')
        return P, Cost, loss, similarity_matrix

    def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
        """Run the configured solver and return ``(P, C, log, similarity_matrix)``.

        Args:
            s1_word_embeddigs: Embeddings of the first sentence (an object
                with a torch-style ``.to(dtype)`` method).
            s2_word_embeddigs: Embeddings of the second sentence.

        Returns:
            P: Alignment matrix. Min-max scaled for every solver type; for
                ``'none'`` it is ``1 - C`` without scaling.
            C: Cost matrix from ``self.dist_func`` (converted to NumPy for
                the ``'ot'`` and ``'pot'`` types).
            log: Solver log dict; empty for ``'none'``.
            similarity_matrix: Second value returned by ``self.dist_func``.

        Raises:
            ValueError: If ``self.ot_type`` is not a supported value.
        """
        s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
        s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

        if self.ot_type == 'none':
            # No solver is run, so there is no solver log; hand back an empty
            # dict so that callers can still use ``log.get(...)``.
            return 1 - C, C, {}, similarity_matrix

        if self.ot_type == 'ot':
            P, C, log = self._solve_balanced(s1_weights, s2_weights, C)
        elif self.ot_type == 'pot':
            if self.chimera:
                # NOTE(review): bertscore_F1 is not defined in this class as
                # visible here; it must be provided elsewhere -- confirm.
                m = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
                m = min(1.0, m.item())
            else:
                m = self.tau
            P, C, log = self._solve_partial(s1_weights, s2_weights, C, m)
        elif self.ot_type in ('uot', 'uot-mm'):
            if self.chimera:
                # NOTE(review): see the bertscore_F1 remark above.
                tau = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
            else:
                tau = self.tau
            P, log = self._solve_unbalanced(s1_weights, s2_weights, C, tau)
        else:
            raise ValueError(
                "Unsupported ot_type {!r}; expected one of "
                "'ot', 'pot', 'uot', 'uot-mm', 'none'".format(self.ot_type)
            )

        # Min-max normalization
        P = min_max_scaling(P)
        return P, C, log, similarity_matrix

    def _solve_balanced(self, s1_weights, s2_weights, C):
        """Balanced OT on weights normalised to sum to one; returns ``(P, C, log)``."""
        s1_weights = s1_weights / s1_weights.sum()
        s2_weights = s2_weights / s2_weights.sum()
        s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)

        if self.sinkhorn:
            P, log = ot.bregman.sinkhorn_log(
                s1_weights, s2_weights, C,
                reg=self.epsilon, stopThr=self.stopThr,
                numItermax=self.numItermax, log=True
            )
        else:
            P, log = ot.emd(s1_weights, s2_weights, C, log=True)
        return P, C, log

    def _solve_partial(self, s1_weights, s2_weights, C, m):
        """Partial OT moving fraction ``m`` of the smaller total mass; returns ``(P, C, log)``."""
        s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
        # Turn the fraction into an absolute mass bounded by the smaller marginal.
        m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m

        if self.sinkhorn:
            P, log = ot.partial.entropic_partial_wasserstein(
                s1_weights, s2_weights, C,
                reg=self.epsilon,
                m=m, stopThr=self.stopThr, numItermax=self.numItermax, log=True
            )
        else:
            # Original note: "To cope with round error".
            # NOTE(review): unclear which step that note refers to -- confirm.
            P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)
        return P, C, log

    def _solve_unbalanced(self, s1_weights, s2_weights, C, tau):
        """Unbalanced OT for ``'uot'`` / ``'uot-mm'``; returns ``(P, log)``.

        The inputs are passed to POT as-is (no NumPy conversion).
        """
        if self.ot_type == 'uot':
            return ot.unbalanced.sinkhorn_stabilized_unbalanced(
                s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau,
                stopThr=self.stopThr, numItermax=self.numItermax, log=True
            )
        return ot.unbalanced.mm_unbalanced(
            s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
            stopThr=self.stopThr, numItermax=self.numItermax, log=True
        )

    def convert_to_numpy(self, s1_weights, s2_weights, C):
        """Move torch inputs to the CPU and return them as NumPy arrays.

        Both weight vectors are converted together, keyed on whether
        ``s1_weights`` is a tensor; ``C`` is checked independently.
        Non-tensor inputs are returned unchanged.
        """
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()
        return s1_weights, s2_weights, C