File size: 4,220 Bytes
527e550
 
 
94f5fd3
527e550
 
 
 
 
 
 
 
37d364a
527e550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94f5fd3
f31ab4f
94f5fd3
 
 
 
527e550
f31ab4f
94f5fd3
527e550
 
 
 
f31ab4f
527e550
 
 
 
 
94f5fd3
527e550
 
94f5fd3
 
 
 
 
527e550
94f5fd3
527e550
 
 
 
94f5fd3
37d364a
527e550
 
94f5fd3
 
 
 
 
527e550
 
94f5fd3
527e550
 
 
 
37d364a
527e550
 
94f5fd3
 
 
 
527e550
94f5fd3
 
 
 
527e550
 
 
 
 
 
f31ab4f
527e550
94f5fd3
527e550
 
 
 
 
 
37d364a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import torch
import ot
from otfuncs import (
    compute_distance_matrix_cosine,
    compute_distance_matrix_l2,
    compute_weights_norm,
    compute_weights_uniform,
    min_max_scaling
)

class Aligner:
    """Word-alignment between two sentences via optimal transport (OT).

    Given per-token embeddings for two sentences, computes a cost matrix and
    token weights, then solves one of several OT formulations selected by
    ``ot_type``:

      - ``'ot'``     : balanced OT — exact EMD, or entropic Sinkhorn if
                       ``sinkhorn`` is true.
      - ``'pot'``    : partial OT transporting a mass fraction ``tau``.
      - ``'uot'``    : unbalanced OT (stabilized Sinkhorn), marginal
                       relaxation ``tau``.
      - ``'uot-mm'`` : unbalanced OT via majorization-minimization with
                       divergence ``div_type``.
      - ``'none'``   : no transport; the alignment matrix is simply ``1 - C``.
    """

    def __init__(self, ot_type, sinkhorn, dist_type, weight_type, distortion, thresh, tau, **kwargs):
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.dist_type = dist_type
        self.weight_type = weight_type
        # Fixed attribute name (was misspelled 'distotion'); used by dist_func.
        self.distortion = distortion
        self.thresh = thresh
        self.tau = tau
        self.epsilon = 0.1      # entropic regularization strength
        self.stopThr = 1e-6     # solver convergence threshold
        self.numItermax = 1000  # solver iteration cap
        # Only needed for ot_type == 'uot-mm'; default to None rather than
        # raising KeyError for every other ot_type.
        self.div_type = kwargs.get('div_type')

        self.dist_func = compute_distance_matrix_cosine if dist_type == 'cos' else compute_distance_matrix_l2
        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
        """Solve OT and return ``(P, Cost, loss, similarity_matrix)``.

        ``P`` is always returned as a numpy array. ``loss`` is the solver's
        reported 'cost' when available, else the string 'NotImplemented'
        (not every POT solver puts a 'cost' entry in its log).
        """
        P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
        if torch.is_tensor(P):
            P = P.to('cpu').numpy()
        loss = log.get('cost', 'NotImplemented')

        return P, Cost, loss, similarity_matrix

    def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
        """Compute the transport plan for the configured ``ot_type``.

        Returns ``(P, C, log, similarity_matrix)`` where ``P`` is the
        (min-max normalized) transport plan, ``C`` the cost matrix and
        ``log`` the solver's log dict (empty for ot_type == 'none').

        Raises:
            ValueError: if ``self.ot_type`` is not one of the supported modes.
        """
        # POT solvers are run in float64 for numerical stability.
        s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
        s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)

        C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distortion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)

        if self.ot_type == 'ot':
            # Balanced OT requires both marginals to be probability vectors.
            s1_weights = s1_weights / s1_weights.sum()
            s2_weights = s2_weights / s2_weights.sum()
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)

            if self.sinkhorn:
                P, log = ot.bregman.sinkhorn_log(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon, stopThr=self.stopThr,
                    numItermax=self.numItermax, log=True
                )
            else:
                P, log = ot.emd(s1_weights, s2_weights, C, log=True)
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'pot':
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            # Transportable mass: tau fraction of the smaller marginal total
            # (min() copes with rounding error between the two sums).
            m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * self.tau

            if self.sinkhorn:
                P, log = ot.partial.entropic_partial_wasserstein(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon,
                    m=m, stopThr=self.stopThr, numItermax=self.numItermax, log=True
                )
            else:
                P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type in ('uot', 'uot-mm'):
            tau = self.tau

            if self.ot_type == 'uot':
                P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
                    s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau,
                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
                )
            else:  # 'uot-mm'
                P, log = ot.unbalanced.mm_unbalanced(
                    s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
                )
            # Min-max normalization
            P = min_max_scaling(P)

        elif self.ot_type == 'none':
            P = 1 - C
            # No solver was run, so there is no log; an empty dict keeps the
            # return contract intact (previously this raised NameError).
            log = {}

        else:
            raise ValueError(f'Unknown ot_type: {self.ot_type!r}')

        return P, C, log, similarity_matrix

    def convert_to_numpy(self, s1_weights, s2_weights, C):
        """Move any torch tensors among the inputs to CPU numpy arrays."""
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()

        return s1_weights, s2_weights, C