import numpy as np import torch import torch.nn.functional as F from ot.backend import get_backend import plotly.graph_objects as go device = "cuda" if torch.cuda.is_available() else "cpu" def compute_distance_matrix_cosine(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio): C = (torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t()) + 1.0) / 2 # Range 0-1 C = apply_distortion(C, distortion_ratio) C = min_max_scaling(C) # Range 0-1 C = 1.0 - C # Convert to distance return C def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio): C = torch.cdist(s1_word_embeddigs, s2_word_embeddigs, p=2) C = min_max_scaling(C) # Range 0-1 C = 1.0 - C # Convert to similarity C = apply_distortion(C, distortion_ratio) C = min_max_scaling(C) # Range 0-1 C = 1.0 - C # Convert to distance return C def apply_distortion(sim_matrix, ratio): shape = sim_matrix.shape if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0: return sim_matrix pos_x = torch.tensor([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])], device=device) pos_y = torch.tensor([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])], device=device) distortion_mask = 1.0 - ((pos_x - pos_y.T) ** 2) * ratio sim_matrix = torch.mul(sim_matrix, distortion_mask) return sim_matrix def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs): s1_weights = torch.norm(s1_word_embeddigs, dim=1) s2_weights = torch.norm(s2_word_embeddigs, dim=1) return s1_weights, s2_weights def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs): s1_weights = torch.ones(s1_word_embeddigs.shape[0], dtype=torch.float64, device=device) s2_weights = torch.ones(s2_word_embeddigs.shape[0], dtype=torch.float64, device=device) # # Uniform weights to make L2 norm=1 # s1_weights /= torch.linalg.norm(s1_weights) # s2_weights /= torch.linalg.norm(s2_weights) return s1_weights, s2_weights def min_max_scaling(C): eps = 1e-10 # Min-max scaling for stabilization nx = get_backend(C) C_min = nx.min(C) C_max = nx.max(C) C = (C - C_min + eps) / (C_max - C_min + eps) return C import seaborn as sns import matplotlib.pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable def plot_align_matrix_heatmap(align_matrix, sent1, sent2, thresh, **kwargs): align_matrix = np.where(align_matrix <= thresh, 0, align_matrix) fig, ax = plt.subplots(figsize=(10, 6)) sns.set(font='sans-serif', style="ticks") _color = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39'] _ticks = [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0] divider = make_axes_locatable(ax) cbar_ax = divider.append_axes("right", size="2.5%", pad=0.1) fig.add_axes(cbar_ax) ax = sns.heatmap( align_matrix, xticklabels=sent1, yticklabels=sent2, cmap=_color, linewidths=1, square=True, ax=ax, cbar_ax=cbar_ax, **kwargs ) ax.collections[0].colorbar.ax.yaxis.set_ticks(_ticks, minor=False) ax.collections[0].colorbar.set_ticklabels(_ticks) cax = ax.collections[0].colorbar.ax cax.tick_params(which='major', length=3, labelsize=5) ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right') ax.set_yticklabels(ax.get_yticklabels(), rotation=0) return fig