import numpy as np
import plotly.graph_objects as go

# Eight equal-width bins on [0, 1] for the alignment heatmap, from
# near-white (lowest bin) to dark teal (highest bin).
_ALIGN_COLORS = (
    "#F2F2F2",
    "#E0F4FA",
    "#BEE4F0",
    "#88CCE5",
    "#33b7df",
    "#1B88A6",
    "#105264",
    "#092E39",
)
_ALIGN_TICKS = (0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0)


def _debug_non_unique_axis_values(
    sent1: list[str], sent2: list[str]
) -> tuple[list[str], list[str]]:
    """Make repeated tokens distinct so each keeps its own axis slot.

    Plotly treats identical category labels as the same axis position, so a
    sentence with a repeated word would lose a row/column. Appending ``i``
    zero-width spaces (U+200B) to the ``i``-th token makes every label unique
    while rendering the same visible text.
    cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013

    Returns:
        A ``(sent1, sent2)`` tuple of new lists; the inputs are not modified.
    """
    sent1 = [word + i * "\u200b" for i, word in enumerate(sent1)]
    sent2 = [word + i * "\u200b" for i, word in enumerate(sent2)]
    return sent1, sent2


def discrete_colorscale(bvals, colors):
    """Build a Plotly colorscale with flat color bands between boundaries.

    Args:
        bvals: Boundary values of the intervals (sorted internally).
        colors: RGB or hex color strings; ``colors[k]`` fills the interval
            ``[bvals[k], bvals[k + 1]]`` for ``0 <= k < len(bvals) - 1``.

    Returns:
        A list of ``[position, color]`` pairs with positions normalized to
        ``[0, 1]``. Each color is listed at both ends of its interval so
        neighbouring bands have hard edges instead of a gradient.

    Raises:
        ValueError: If ``len(bvals) != len(colors) + 1``, or if all boundary
            values are equal (the range cannot be normalized).

    ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
    """
    if len(bvals) != len(colors) + 1:
        raise ValueError("len(boundary values) should be equal to len(colors)+1")
    bvals = sorted(bvals)
    lo = bvals[0]
    span = bvals[-1] - lo
    if span == 0:
        # Previously a ZeroDivisionError (or NaN positions for numpy scalars).
        raise ValueError("boundary values must not all be equal")
    nvals = [(v - lo) / span for v in bvals]  # normalized to [0, 1]
    dcolorscale = []  # discrete colorscale
    for color, start, end in zip(colors, nvals, nvals[1:]):
        dcolorscale.extend([[start, color], [end, color]])
    return dcolorscale


def _zero_at_or_below(matrix, threshold):
    """Return *matrix* as an ndarray with entries ``<= threshold`` set to 0."""
    matrix = np.asarray(matrix)  # allows nested lists as well as arrays
    return np.where(matrix <= threshold, 0, matrix)


def _hover_template(z_label):
    """Hover text with both tokens, the cell value (as *z_label*) and its cost.

    Plotly hover templates need ``<br>`` for line breaks; a raw newline is
    not rendered as a break.
    """
    return (
        "x: %{x}<br>"
        + "y: %{y}<br>"
        + f"{z_label}: "
        + "%{z:.3f}<br>"
        + "cost: %{customdata:.3f} "
    )


def _token_heatmap(z, cost, sent1, sent2, colorscale, z_label):
    """Build the token-by-token heatmap figure shared by the public plotters.

    ``sent1`` labels the x axis and ``sent2`` the y axis. The y axis is
    reversed so the first ``sent2`` token is drawn at the top.
    """
    x_labels, y_labels = _debug_non_unique_axis_values(sent1, sent2)
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            z=z,
            customdata=cost,
            x=x_labels,
            y=y_labels,
            xgap=2,
            ygap=2,
            colorscale=colorscale,
            colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
            hovertemplate=_hover_template(z_label),
            # Presumably empty so no trace name appears in the hover label.
            name="",
        )
    )
    fig.update_layout(
        yaxis=dict(autorange="reversed"),
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        plot_bgcolor="rgba(0,0,0,0)",
        font=dict(size=16),
        hoverlabel=dict(
            bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
        ),
    )
    fig.update_xaxes(tickangle=-45)
    return fig


def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
    """Plot an alignment matrix as a heatmap with eight discrete color bins.

    Args:
        align_matrix: 2-D array-like of alignment values (shown as ``P`` on
            hover); columns are labelled by ``sent1``, rows by ``sent2``.
        sent1: Tokens labelling the x axis.
        sent2: Tokens labelling the y axis.
        threshhold: Entries ``<= threshhold`` are drawn as 0.
        Cost: Per-cell values shown as ``cost`` on hover (presumably the
            same shape as ``align_matrix`` — confirm with callers).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # NOTE(review): Plotly stretches the colorscale over the data's own
    # min..max unless zmin/zmax are set, so the color bins line up with the
    # 0.125 colorbar ticks only when the data spans exactly [0, 1]. Consider
    # zmin=0, zmax=1 if the values are probabilities — confirm intent.
    align_matrix = _zero_at_or_below(align_matrix, threshhold)
    colorscale = discrete_colorscale(_ALIGN_TICKS, _ALIGN_COLORS)
    return _token_heatmap(align_matrix, Cost, sent1, sent2, colorscale, "P")


def plot_similarity_matrix_heatmap_plotly(
    similarity_matrix, sent1, sent2, Cost, colorscale="Reds", hover_z="cosine"
):
    """Plot a similarity matrix as a continuous-color heatmap.

    Args:
        similarity_matrix: 2-D array-like; columns are labelled by ``sent1``,
            rows by ``sent2``.
        sent1: Tokens labelling the x axis.
        sent2: Tokens labelling the y axis.
        Cost: Per-cell values shown as ``cost`` on hover.
        colorscale: Any Plotly colorscale name or list.
        hover_z: Label used for the cell value in the hover text.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    return _token_heatmap(similarity_matrix, Cost, sent1, sent2, colorscale, hover_z)


def show_assignments_plotly(P, word_embeddings, sents1, sents2, thr=0):
    """Scatter-plot two token sets in 2-D and connect pairs weighted by ``P``.

    Args:
        P: 2-D array-like; ``P[i, j]`` weights the link from source token
            ``i`` to target token ``j``. Entries ``<= thr`` are not drawn.
        word_embeddings: Array whose first ``len(sents1)`` rows are the source
            points and remaining rows the target points; columns 0 and 1 are
            used as x and y.
        sents1: Source token labels.
        sents2: Target token labels.
        thr: Links with weight ``<= thr`` are dropped.

    Returns:
        A ``plotly.graph_objects.Figure`` with one line trace per drawn link
        (named by its 1-based order), followed by the source and target
        marker traces.
    """
    P = _zero_at_or_below(P, thr)
    n_src = len(sents1)
    src = word_embeddings[:n_src]
    tgt = word_embeddings[n_src:]

    traces = []
    sample = 0
    for i in range(src.shape[0]):
        for j in range(tgt.shape[0]):
            weight = P[i, j]
            if weight > 0:
                sample += 1
                # Width and opacity both scale with the link weight.
                # NOTE(review): Plotly rejects opacity outside [0, 1], so this
                # assumes every entry of P is at most 1 — confirm with callers.
                traces.append(
                    go.Scatter(
                        x=[src[i, 0], tgt[j, 0]],
                        y=[src[i, 1], tgt[j, 1]],
                        mode="lines",
                        line=dict(color="black", width=weight * 2),
                        opacity=weight,
                        name=f"{sample}",
                    )
                )

    # Draw the source samples.
    traces.append(
        go.Scatter(
            x=src[:, 0],
            y=src[:, 1],
            mode="markers+text",
            marker=dict(color="blue", size=8, symbol="cross"),
            text=sents1,
            textposition="top center",
            name="Source samples",
        )
    )
    # Draw the target samples.
    traces.append(
        go.Scatter(
            x=tgt[:, 0],
            y=tgt[:, 1],
            mode="markers+text",
            marker=dict(color="red", size=8, symbol="x"),
            text=sents2,
            textposition="bottom center",
            name="Target samples",
        )
    )

    layout = go.Layout(showlegend=True, margin=dict(l=0, r=0, t=10, b=0))
    return go.Figure(data=traces, layout=layout)