# File size: 5,171 Bytes (527e550)
import numpy as np
import torch
import ot
from utils import (
compute_distance_matrix_cosine,
compute_distance_matrix_l2,
compute_weights_norm,
compute_weights_uniform,
min_max_scaling
)
class Aligner:
    """Compute alignment matrices between paired embedding sequences.

    For each pair of inputs a cost matrix ``C`` and two weight vectors are
    built with the configured distance / weight functions, and a transport
    plan ``P`` is computed with the POT (``ot``) solver selected by
    ``ot_type``:

    * ``'ot'``     -- balanced OT (``ot.emd`` or ``ot.bregman.sinkhorn_log``)
    * ``'pot'``    -- partial OT (``ot.partial``)
    * ``'uot'``    -- unbalanced OT (``sinkhorn_stabilized_unbalanced``)
    * ``'uot-mm'`` -- unbalanced OT (``mm_unbalanced``)
    * ``'none'``   -- no transport; ``P = 1 - C``

    Entries of ``P`` strictly greater than a threshold become alignments.
    """

    # The only ``ot_type`` values for which a plan can be produced.
    _OT_TYPES = ('ot', 'pot', 'uot', 'uot-mm', 'none')

    def __init__(self, ot_type, sinkhorn, chimera, dist_type, weight_type, distortion, thresh, tau,
                 **kwargs):
        """Store the configuration and pick the distance / weight functions.

        Args:
            ot_type: One of ``'ot'``, ``'pot'``, ``'uot'``, ``'uot-mm'``, ``'none'``.
            sinkhorn: Truthy to use the entropic solver in the ``'ot'`` and
                ``'pot'`` modes instead of the exact one.
            chimera: Truthy to scale ``tau`` by ``self.bertscore_F1(...)`` in
                the ``'pot'`` / ``'uot'`` / ``'uot-mm'`` modes.
            dist_type: ``'cos'`` selects ``compute_distance_matrix_cosine``;
                anything else selects ``compute_distance_matrix_l2``.
            weight_type: ``'uniform'`` selects ``compute_weights_uniform``;
                anything else selects ``compute_weights_norm``.
            distortion: Passed as the third argument to the distance function.
            thresh: Initial alignment threshold (replaced by ``get_alignments``).
            tau: Mass fraction (``'pot'``) or ``reg_m`` (``'uot*'``) parameter.
            **kwargs: ``div_type`` is forwarded as ``div`` to
                ``ot.unbalanced.mm_unbalanced``; defaults to ``'kl'`` when absent.
        """
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.chimera = chimera
        self.dist_type = dist_type
        self.weight_type = weight_type
        # The misspelled attribute name is kept so existing readers of
        # ``aligner.distotion`` keep working.
        self.distotion = distortion
        self.thresh = thresh
        self.tau = tau

        # Solver settings shared by every iterative POT call below.
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000

        # Only 'uot-mm' reads div_type, so a missing key must not break the
        # other modes.
        self.div_type = kwargs.get('div_type', 'kl')

        # Filled by compute_alignment_matrixes(); defined here so that
        # get_alignments() fails with its own check, not an AttributeError.
        self.align_matrixes = []

        if dist_type == 'cos':
            self.dist_func = compute_distance_matrix_cosine
        else:
            self.dist_func = compute_distance_matrix_l2

        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a CPU numpy array if it is a tensor, else unchanged."""
        if torch.is_tensor(value):
            # detach() so tensors that track gradients can be converted too.
            return value.detach().to('cpu').numpy()
        return value

    def compute_alignment_matrixes(self, s1_vecs, s2_vecs):
        """Compute and store one alignment matrix per ``(s1, s2)`` pair.

        The two iterables are walked in lockstep with ``zip``. Results are
        stored in ``self.align_matrixes`` as numpy arrays (tensor results are
        converted); any previously stored matrices are discarded.
        """
        self.align_matrixes = []
        for vec_x, vec_y in zip(s1_vecs, s2_vecs):
            plan = self.compute_optimal_transport(vec_x, vec_y)
            self.align_matrixes.append(self._to_numpy(plan))

    def get_alignments(self, thresh, assign_cost=False):
        """Threshold every stored matrix into a set of alignment strings.

        Args:
            thresh: Entries strictly greater than this are kept. The value is
                also stored in ``self.thresh``.
            assign_cost: When truthy, append the matrix value to each string.

        Returns:
            A list with one ``set`` of strings per stored matrix.

        Raises:
            AssertionError: If no alignment matrices have been computed.
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O.
        if not self.align_matrixes:
            raise AssertionError('call compute_alignment_matrixes() before get_alignments()')
        self.thresh = thresh
        return [self.matrix_to_alignments(P, assign_cost) for P in self.align_matrixes]

    def matrix_to_alignments(self, P, assign_cost):
        """Convert one matrix into ``'i-j'`` (or ``'i-j-value'``) strings.

        Every index pair whose entry is strictly greater than ``self.thresh``
        is emitted; with ``assign_cost`` the entry is appended with four
        decimal places.
        """
        align_pairs = np.argwhere(P > self.thresh)
        if assign_cost:
            return {
                '{0}-{1}-{2:.4f}'.format(i_j[0], i_j[1], P[i_j[0], i_j[1]])
                for i_j in align_pairs
            }
        return {'{0}-{1}'.format(i_j[0], i_j[1]) for i_j in align_pairs}

    def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
        """Compute the alignment matrix for a single pair of embedding tensors.

        Both inputs are cast to ``torch.float64`` before the cost matrix and
        weights are computed. For every mode except ``'none'`` the resulting
        plan is passed through ``min_max_scaling``.

        Raises:
            ValueError: If ``self.ot_type`` is not a supported mode.
        """
        ot_type = self.ot_type
        if ot_type not in self._OT_TYPES:
            raise ValueError(
                'unsupported ot_type {0!r}; expected one of {1}'.format(ot_type, self._OT_TYPES)
            )

        s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
        s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

        C = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

        if ot_type == 'none':
            return 1 - C

        if ot_type == 'ot':
            P = self._balanced_plan(s1_weights, s2_weights, C)
        elif ot_type == 'pot':
            P = self._partial_plan(s1_word_embeddigs, s2_word_embeddigs, s1_weights, s2_weights, C)
        else:  # 'uot' or 'uot-mm'
            P = self._unbalanced_plan(s1_word_embeddigs, s2_word_embeddigs, s1_weights, s2_weights, C)

        # Min-max normalization
        return min_max_scaling(P)

    def _balanced_plan(self, s1_weights, s2_weights, C):
        """Solve balanced OT after normalising each weight vector to sum to 1."""
        s1_weights = s1_weights / s1_weights.sum()
        s2_weights = s2_weights / s2_weights.sum()
        s1_weights, s2_weights, C = self.comvert_to_numpy(s1_weights, s2_weights, C)
        if self.sinkhorn:
            return ot.bregman.sinkhorn_log(s1_weights, s2_weights, C, reg=self.epsilon,
                                           stopThr=self.stopThr, numItermax=self.numItermax)
        return ot.emd(s1_weights, s2_weights, C)

    def _partial_plan(self, s1_word_embeddigs, s2_word_embeddigs, s1_weights, s2_weights, C):
        """Solve partial OT, transporting a fraction of the smaller total mass."""
        if self.chimera:
            # NOTE(review): bertscore_F1 is not defined in the visible part of
            # this class -- confirm where it is provided.
            m = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
            m = min(1.0, m.item())  # the fraction is capped at 1.0
        else:
            m = self.tau
        s1_weights, s2_weights, C = self.comvert_to_numpy(s1_weights, s2_weights, C)
        # Turn the fraction into an absolute mass bounded by the lighter side.
        m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m
        if self.sinkhorn:
            return ot.partial.entropic_partial_wasserstein(s1_weights, s2_weights, C,
                                                           reg=self.epsilon, m=m,
                                                           stopThr=self.stopThr,
                                                           numItermax=self.numItermax)
        # NOTE(review): the original carried a "To cope with round error"
        # comment here with no accompanying code -- confirm nothing is missing.
        return ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m)

    def _unbalanced_plan(self, s1_word_embeddigs, s2_word_embeddigs, s1_weights, s2_weights, C):
        """Solve unbalanced OT; weights and cost are passed to POT unconverted."""
        if self.chimera:
            # NOTE(review): bertscore_F1 is not defined in the visible part of
            # this class -- confirm where it is provided.
            tau = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
        else:
            tau = self.tau
        if self.ot_type == 'uot':
            return ot.unbalanced.sinkhorn_stabilized_unbalanced(s1_weights, s2_weights, C,
                                                                reg=self.epsilon, reg_m=tau,
                                                                stopThr=self.stopThr,
                                                                numItermax=self.numItermax)
        return ot.unbalanced.mm_unbalanced(s1_weights, s2_weights, C, reg_m=tau,
                                           div=self.div_type, stopThr=self.stopThr,
                                           numItermax=self.numItermax)

    def comvert_to_numpy(self, s1_weights, s2_weights, C):
        """Return the three inputs with any torch tensors converted to numpy.

        Each value is checked on its own, so a mix of tensors and arrays is
        accepted; non-tensor values are returned unchanged.
        """
        return self._to_numpy(s1_weights), self._to_numpy(s2_weights), self._to_numpy(C)