|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from ot.backend import get_backend |
|
import plotly.graph_objects as go |
|
|
|
# Module-level device string: "cuda" when a CUDA device is available,
# otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
def compute_distance_matrix_cosine(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
    """Build a cost matrix from pairwise cosine similarity of two embedding sets.

    The embeddings are normalised with ``F.normalize`` (default arguments),
    multiplied together, and the result is shifted from [-1, 1] into [0, 1].
    That similarity is passed through ``apply_distortion`` and
    ``min_max_scaling`` and finally flipped with ``1 - x`` so that larger
    values mean "further apart".

    Args:
        s1_word_embeddigs: Embeddings of the first sentence
            (assumed 2-D, one row per word -- TODO confirm with callers).
        s2_word_embeddigs: Embeddings of the second sentence, same layout.
        distortion_ratio: Strength forwarded to ``apply_distortion``.

    Returns:
        Tensor with one row per row of ``s1_word_embeddigs`` and one column
        per row of ``s2_word_embeddigs``.
    """
    s1_unit = F.normalize(s1_word_embeddigs)
    s2_unit = F.normalize(s2_word_embeddigs)

    # Cosine lives in [-1, 1]; move it to [0, 1] before distorting/scaling.
    similarity = (torch.matmul(s1_unit, s2_unit.t()) + 1.0) / 2
    similarity = min_max_scaling(apply_distortion(similarity, distortion_ratio))

    # Similarity -> cost.
    return 1.0 - similarity
|
|
|
|
|
def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
    """Build a cost matrix from pairwise Euclidean distances of two embedding sets.

    The raw ``torch.cdist`` distances are min-max scaled and flipped into a
    similarity so that ``apply_distortion`` can be applied to them; the
    distorted similarity is then min-max scaled again and flipped back into
    a cost.

    Args:
        s1_word_embeddigs: Embeddings of the first sentence.
        s2_word_embeddigs: Embeddings of the second sentence.
        distortion_ratio: Strength forwarded to ``apply_distortion``.

    Returns:
        The cost tensor produced by the steps above.
    """
    distances = torch.cdist(s1_word_embeddigs, s2_word_embeddigs, p=2)

    # Distance -> similarity, because the distortion mask scales similarities.
    similarity = 1.0 - min_max_scaling(distances)
    distorted = apply_distortion(similarity, distortion_ratio)

    # Similarity -> cost.
    return 1.0 - min_max_scaling(distorted)
|
|
|
|
|
def apply_distortion(sim_matrix, ratio):
    """Down-weight similarities between entries at different relative positions.

    Each cell ``(i, j)`` is multiplied by
    ``1 - ratio * (j / (cols - 1) - i / (rows - 1)) ** 2``, so cells on the
    (rescaled) diagonal are left unchanged and cells far from it are
    attenuated the most.

    The mask is created on ``sim_matrix.device`` so the multiplication works
    regardless of which device the input lives on, and it is built from two
    broadcast position vectors instead of nested Python lists.

    Args:
        sim_matrix: Similarity tensor; its first two dimensions are used as
            (rows, cols).
        ratio: Distortion strength. ``0.0`` disables the distortion.

    Returns:
        ``sim_matrix`` itself when either of the first two dimensions is
        smaller than 2 or ``ratio == 0.0``; otherwise the element-wise
        product of ``sim_matrix`` and the distortion mask.
    """
    n_rows, n_cols = sim_matrix.shape[0], sim_matrix.shape[1]
    # With a single row/column the relative position is undefined (0 / 0).
    if n_rows < 2 or n_cols < 2 or ratio == 0.0:
        return sim_matrix

    # Keep the mask in torch's default float dtype so type promotion of the
    # product is the same as when the mask came from Python floats.
    mask_dtype = torch.get_default_dtype()
    target_device = sim_matrix.device

    # Relative positions in [0, 1] along each axis.
    row_pos = torch.arange(n_rows, dtype=mask_dtype, device=target_device) / (n_rows - 1)
    col_pos = torch.arange(n_cols, dtype=mask_dtype, device=target_device) / (n_cols - 1)

    # Broadcast (1, cols) against (rows, 1) to get the full (rows, cols) mask.
    offset = col_pos.unsqueeze(0) - row_pos.unsqueeze(1)
    distortion_mask = 1.0 - (offset ** 2) * ratio

    return torch.mul(sim_matrix, distortion_mask)
|
|
|
|
|
def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
    """Weight each row of both embedding sets by its norm.

    Args:
        s1_word_embeddigs: Embeddings of the first sentence.
        s2_word_embeddigs: Embeddings of the second sentence.

    Returns:
        Tuple ``(s1_weights, s2_weights)`` holding ``torch.norm(..., dim=1)``
        of the respective input.
    """
    return (
        torch.norm(s1_word_embeddigs, dim=1),
        torch.norm(s2_word_embeddigs, dim=1),
    )
|
|
|
|
|
def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
    """Give every row of both embedding sets the same weight of 1.

    The weight vectors are allocated on the device of the corresponding
    embedding tensor, so they can be combined with tensors derived from
    those embeddings without a device mismatch.

    Args:
        s1_word_embeddigs: Embeddings of the first sentence; only its first
            dimension and device are used.
        s2_word_embeddigs: Embeddings of the second sentence; only its first
            dimension and device are used.

    Returns:
        Tuple ``(s1_weights, s2_weights)`` of 1-D ``float64`` tensors of ones
        with one entry per row of the respective input.
    """
    s1_weights = torch.ones(
        s1_word_embeddigs.shape[0],
        dtype=torch.float64,
        device=s1_word_embeddigs.device,
    )
    s2_weights = torch.ones(
        s2_word_embeddigs.shape[0],
        dtype=torch.float64,
        device=s2_word_embeddigs.device,
    )
    return s1_weights, s2_weights
|
|
|
|
|
def min_max_scaling(C):
    """Rescale ``C`` using its global minimum and maximum.

    A small ``eps`` is added to both numerator and denominator, which keeps
    the division defined when all entries are equal (the result is then 1
    everywhere) and maps the minimum to a value slightly above zero.

    Args:
        C: Array/tensor accepted by ``ot.backend.get_backend``.

    Returns:
        ``(C - min + eps) / (max - min + eps)``.
    """
    eps = 1e-10

    # Use the POT backend matching C's array type for the reductions.
    backend = get_backend(C)
    lowest = backend.min(C)
    highest = backend.max(C)

    return (C - lowest + eps) / (highest - lowest + eps)
|
|
|
|
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
from mpl_toolkits.axes_grid1 import make_axes_locatable |
|
|
|
def plot_align_matrix_heatmap(align_matrix, sent1, sent2, thresh, **kwargs):
    """Draw an alignment matrix as a seaborn heatmap and return the figure.

    Args:
        align_matrix: Alignment scores; entries ``<= thresh`` are replaced
            by 0 before plotting.
        sent1: Labels used for the x axis ticks.
        sent2: Labels used for the y axis ticks.
        thresh: Cut-off below (or at) which scores are zeroed.
        **kwargs: Extra keyword arguments forwarded to ``sns.heatmap``.

    Returns:
        The matplotlib figure containing the heatmap and its colorbar.
    """
    # Hide alignments at or below the threshold.
    masked = np.where(align_matrix <= thresh, 0, align_matrix)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.set(font='sans-serif', style="ticks")

    # Eight-step light-to-dark palette and the matching colorbar ticks.
    palette = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
    tick_values = [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    # Reserve a narrow axes to the right of the heatmap for the colorbar.
    cbar_ax = make_axes_locatable(ax).append_axes("right", size="2.5%", pad=0.1)
    fig.add_axes(cbar_ax)

    heat_ax = sns.heatmap(
        masked,
        xticklabels=sent1,
        yticklabels=sent2,
        cmap=palette,
        linewidths=1,
        square=True,
        ax=ax,
        cbar_ax=cbar_ax,
        **kwargs
    )

    # Configure the colorbar attached to the heatmap's first collection.
    colorbar = heat_ax.collections[0].colorbar
    colorbar.ax.yaxis.set_ticks(tick_values, minor=False)
    colorbar.set_ticklabels(tick_values)
    colorbar.ax.tick_params(which='major', length=3, labelsize=5)

    heat_ax.set_xticklabels(heat_ax.get_xticklabels(), rotation=45, ha='right')
    heat_ax.set_yticklabels(heat_ax.get_yticklabels(), rotation=0)

    return fig