# UOT / aligner.py  (author: 4kasha, commit 527e550)
import numpy as np
import torch
import ot
from utils import (
compute_distance_matrix_cosine,
compute_distance_matrix_l2,
compute_weights_norm,
compute_weights_uniform,
min_max_scaling
)
class Aligner:
    """Word-level aligner based on (possibly unbalanced/partial) optimal transport.

    Given per-sentence word embeddings, computes a transport plan P between the
    words of two sentences and extracts aligned index pairs where P exceeds a
    threshold.

    Supported ``ot_type`` values:
      - ``'ot'``     : balanced OT (exact EMD, or Sinkhorn when ``sinkhorn`` is set).
      - ``'pot'``    : partial OT transporting a mass fraction controlled by ``tau``.
      - ``'uot'``    : unbalanced OT (stabilized Sinkhorn) with relaxation ``tau``.
      - ``'uot-mm'`` : unbalanced OT via majorization-minimization.
      - ``'none'``   : no transport; the plan is simply ``1 - cost``.
    """

    def __init__(self, ot_type, sinkhorn, chimera, dist_type, weight_type,
                 distortion, thresh, tau, **kwargs):
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn        # use entropic (Sinkhorn) solver instead of exact
        self.chimera = chimera          # scale tau by BERTScore-F1 of the sentence pair
        self.dist_type = dist_type      # 'cos' -> cosine distance, otherwise L2
        self.weight_type = weight_type  # 'uniform' or norm-based word weights
        self.distortion = distortion    # distortion parameter forwarded to the distance function
        self.thresh = thresh            # alignment-extraction threshold on P
        self.tau = tau                  # mass fraction ('pot') / marginal relaxation ('uot*')
        # Sinkhorn solver hyperparameters.
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000
        # Divergence for the 'uot-mm' solver only; 'kl' matches POT's default,
        # so configurations that never use 'uot-mm' need not supply it.
        self.div_type = kwargs.get('div_type', 'kl')

        self.dist_func = (compute_distance_matrix_cosine if dist_type == 'cos'
                          else compute_distance_matrix_l2)
        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_vecs, s2_vecs):
        """Compute and cache one alignment matrix (as numpy) per sentence pair.

        ``s1_vecs`` / ``s2_vecs`` are parallel iterables of word-embedding
        tensors, one entry per sentence.
        """
        self.align_matrixes = []
        for vecX, vecY in zip(s1_vecs, s2_vecs):
            P = self.compute_optimal_transport(vecX, vecY)
            if torch.is_tensor(P):
                P = P.to('cpu').numpy()
            self.align_matrixes.append(P)

    def get_alignments(self, thresh, assign_cost=False):
        """Extract alignment pairs from the cached matrices.

        Must be called after ``compute_alignment_matrixes``. Returns a list of
        sets of 'i-j' strings (or 'i-j-cost' when ``assign_cost`` is true),
        one set per sentence pair.
        """
        assert len(self.align_matrixes) > 0, \
            'compute_alignment_matrixes must be called first'
        self.thresh = thresh
        return [self.matrix_to_alignments(P, assign_cost)
                for P in self.align_matrixes]

    def matrix_to_alignments(self, P, assign_cost):
        """Return the set of index pairs (i, j) with P[i, j] > self.thresh.

        Pairs are encoded as 'i-j', or 'i-j-<P[i,j]:.4f>' when ``assign_cost``.
        """
        alignments = set()
        align_pairs = np.transpose(np.nonzero(P > self.thresh))
        if assign_cost:
            for i, j in align_pairs:
                alignments.add('{0}-{1}-{2:.4f}'.format(i, j, P[i, j]))
        else:
            for i, j in align_pairs:
                alignments.add('{0}-{1}'.format(i, j))
        return alignments

    def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
        """Solve the configured OT problem for one sentence pair.

        Returns the transport plan P (min-max scaled except for 'none');
        depending on the branch taken, P may be a numpy array or a torch
        tensor — callers normalize via ``torch.is_tensor``.

        Raises:
            ValueError: if ``self.ot_type`` is not a recognized mode.
        """
        # POT solvers are sensitive to precision; work in float64.
        s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
        s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

        C = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distortion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

        if self.ot_type == 'ot':
            # Balanced OT needs probability distributions on both sides.
            s1_weights = s1_weights / s1_weights.sum()
            s2_weights = s2_weights / s2_weights.sum()
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            if self.sinkhorn:
                P = ot.bregman.sinkhorn_log(s1_weights, s2_weights, C,
                                            reg=self.epsilon,
                                            stopThr=self.stopThr,
                                            numItermax=self.numItermax)
            else:
                P = ot.emd(s1_weights, s2_weights, C)
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'pot':
            if self.chimera:
                # NOTE(review): bertscore_F1 is not defined in this file;
                # presumably provided by a subclass or mixin — verify.
                m = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
                m = min(1.0, m.item())
            else:
                m = self.tau
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            # Transportable mass cannot exceed the smaller total marginal mass.
            m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m
            if self.sinkhorn:
                P = ot.partial.entropic_partial_wasserstein(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon, m=m,
                    stopThr=self.stopThr, numItermax=self.numItermax)
            else:
                # To cope with round error
                P = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m)
            # Min-max normalization
            P = min_max_scaling(P)

        elif 'uot' in self.ot_type:
            if self.chimera:
                # NOTE(review): see bertscore_F1 note above — defined elsewhere.
                tau = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
            else:
                tau = self.tau
            if self.ot_type == 'uot':
                P = ot.unbalanced.sinkhorn_stabilized_unbalanced(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon, reg_m=tau,
                    stopThr=self.stopThr, numItermax=self.numItermax)
            elif self.ot_type == 'uot-mm':
                P = ot.unbalanced.mm_unbalanced(
                    s1_weights, s2_weights, C,
                    reg_m=tau, div=self.div_type,
                    stopThr=self.stopThr, numItermax=self.numItermax)
            else:
                # Previously fell through with P unbound (UnboundLocalError).
                raise ValueError('Unknown uot variant: {0}'.format(self.ot_type))
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'none':
            # No transport: similarity is the complement of the distance.
            P = 1 - C

        else:
            raise ValueError('Unknown ot_type: {0}'.format(self.ot_type))

        return P

    def convert_to_numpy(self, s1_weights, s2_weights, C):
        """Move weights and cost matrix to CPU numpy arrays when they are torch tensors."""
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()
        return s1_weights, s2_weights, C

    # Backward-compatible alias for the original (misspelled) method name.
    comvert_to_numpy = convert_to_numpy