import numpy as np
import torch
import ot
from utils import (
    compute_distance_matrix_cosine,
    compute_distance_matrix_l2,
    compute_weights_norm,
    compute_weights_uniform,
    min_max_scaling
)

class Aligner:
    def __init__(self, ot_type, sinkhorn, chimera, dist_type, weight_type, distortion, thresh, tau, **kwargs):
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.chimera = chimera
        self.dist_type = dist_type
        self.weight_type = weight_type
        self.distortion = distortion
        self.thresh = thresh
        self.tau = tau
        # Entropic regularization and convergence settings for the Sinkhorn-based solvers
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000
        self.div_type = kwargs.get('div_type')  # divergence for 'uot-mm'; unused by the other OT types

        # Cost function (cosine or L2 distance) and marginal weights (uniform or norm-based)
        self.dist_func = compute_distance_matrix_cosine if dist_type == 'cos' else compute_distance_matrix_l2
        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_vecs, s2_vecs):
        # Compute one alignment (transport) matrix per sentence pair and cache them as numpy arrays
        self.align_matrixes = []
        for vecX, vecY in zip(s1_vecs, s2_vecs):
            P = self.compute_optimal_transport(vecX, vecY)
            if torch.is_tensor(P):
                P = P.to('cpu').numpy()

            self.align_matrixes.append(P)

    def get_alignments(self, thresh, assign_cost=False):
        assert len(self.align_matrixes) > 0

        self.thresh = thresh
        all_alignments = []
        for P in self.align_matrixes:
            alignments = self.matrix_to_alignments(P, assign_cost)
            all_alignments.append(alignments)

        return all_alignments

    def matrix_to_alignments(self, P, assign_cost):
        # Collect token pairs whose transport mass exceeds the threshold,
        # encoded as 'i-j' strings (with the mass appended when assign_cost is True)
        alignments = set()
        align_pairs = np.transpose(np.nonzero(P > self.thresh))
        if assign_cost:
            for i_j in align_pairs:
                alignments.add('{0}-{1}-{2:.4f}'.format(i_j[0], i_j[1], P[i_j[0], i_j[1]]))
        else:
            for i_j in align_pairs:
                alignments.add('{0}-{1}'.format(i_j[0], i_j[1]))

        return alignments

    def compute_optimal_transport(self, s1_word_embeddings, s2_word_embeddings):
        s1_word_embeddings = s1_word_embeddings.to(torch.float64)
        s2_word_embeddings = s2_word_embeddings.to(torch.float64)

        # Pairwise cost matrix and marginal weights for the two sentences
        C = self.dist_func(s1_word_embeddings, s2_word_embeddings, self.distortion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddings, s2_word_embeddings)

        if self.ot_type == 'ot':
            # Balanced OT: normalize both marginals to probability distributions
            s1_weights = s1_weights / s1_weights.sum()
            s2_weights = s2_weights / s2_weights.sum()
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)

            if self.sinkhorn:
                P = ot.bregman.sinkhorn_log(s1_weights, s2_weights, C, reg=self.epsilon, stopThr=self.stopThr,
                                            numItermax=self.numItermax)
            else:
                P = ot.emd(s1_weights, s2_weights, C)
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'pot':
            # Partial OT: transport only a fraction m of the total mass
            if self.chimera:
                # self.bertscore_F1 is expected to be provided elsewhere (it is not defined in this class)
                m = self.tau * self.bertscore_F1(s1_word_embeddings, s2_word_embeddings)
                m = min(1.0, m.item())
            else:
                m = self.tau

            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m

            if self.sinkhorn:
                P = ot.partial.entropic_partial_wasserstein(s1_weights, s2_weights, C,
                                                            reg=self.epsilon,
                                                            m=m, stopThr=self.stopThr, numItermax=self.numItermax)
            else:
                # To cope with rounding errors
                P = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m)
            # Min-max normalization
            P = min_max_scaling(P)

        elif 'uot' in self.ot_type:
            # Unbalanced OT: tau controls the marginal relaxation strength
            if self.chimera:
                tau = self.tau * self.bertscore_F1(s1_word_embeddings, s2_word_embeddings)
            else:
                tau = self.tau

            if self.ot_type == 'uot':
                P = ot.unbalanced.sinkhorn_stabilized_unbalanced(s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau,
                                                                 stopThr=self.stopThr, numItermax=self.numItermax)
            elif self.ot_type == 'uot-mm':
                P = ot.unbalanced.mm_unbalanced(s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
                                                stopThr=self.stopThr, numItermax=self.numItermax)
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'none':
            # No transport: use similarity (1 - cost) directly as the alignment matrix
            P = 1 - C

        return P

    def convert_to_numpy(self, s1_weights, s2_weights, C):
        # Move weights and cost matrix to CPU numpy arrays for the POT solvers
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()

        return s1_weights, s2_weights, C
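
# Minimal usage sketch (not part of the original module): aligns one toy sentence pair
# of random embeddings with plain balanced OT and prints the resulting 'i-j' pairs.
# The embedding dimension, threshold, distortion value, and div_type below are
# illustrative assumptions, not values prescribed by this repository.
if __name__ == '__main__':
    torch.manual_seed(0)
    s1_vecs = [torch.rand(5, 32)]  # one sentence: 5 tokens, 32-dim embeddings (placeholder data)
    s2_vecs = [torch.rand(6, 32)]  # one sentence: 6 tokens

    aligner = Aligner(ot_type='ot', sinkhorn=True, chimera=False,
                      dist_type='cos', weight_type='uniform',
                      distortion=0.0, thresh=0.1, tau=0.3,
                      div_type='kl')  # div_type only matters for ot_type='uot-mm'
    aligner.compute_alignment_matrixes(s1_vecs, s2_vecs)
    # Each element is a set of 'i-j' strings linking token i of s1 to token j of s2
    print(aligner.get_alignments(thresh=0.1))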