import numpy as np
import plotly.graph_objects as go


def _debug_non_unique_axis_values(
    sent1: list[str], sent2: list[str]
) -> tuple[list[str], list[str]]:
    """Make axis labels distinct so Plotly does not merge repeated words.

    Plotly treats equal category labels as the same axis position, so a word
    occurring twice in a sentence would collapse into one row/column. The
    label at position ``i`` gets ``i`` zero-width spaces appended, which is
    invisible when rendered but makes the strings differ.

    cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013

    Args:
        sent1: Tokens used as x-axis labels.
        sent2: Tokens used as y-axis labels.

    Returns:
        The two label lists with the zero-width-space padding applied.
    """
    sent1 = [word + i * '\u200b' for i, word in enumerate(sent1)]
    sent2 = [word + i * '\u200b' for i, word in enumerate(sent2)]
    return sent1, sent2


def discrete_colorscale(bvals, colors):
    """Build a stepwise (discrete) Plotly colorscale.

    ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780

    Args:
        bvals: Values bounding the intervals of interest; sorted internally.
        colors: RGB or hex color codes; ``colors[k]`` is used for values in
            ``[bvals[k], bvals[k+1]]`` with ``0 <= k < len(bvals) - 1``.

    Returns:
        A list of ``[normalized_position, color]`` pairs. Each color appears
        twice (at the start and the end of its interval) so the scale is
        constant inside an interval.

    Raises:
        ValueError: If ``len(bvals) != len(colors) + 1``, or if the boundary
            values do not span a non-zero range (normalization would divide
            by zero).
    """
    if len(bvals) != len(colors) + 1:
        raise ValueError('len(boundary values) should be equal to len(colors)+1')
    bvals = sorted(bvals)

    span = bvals[-1] - bvals[0]
    if span == 0:
        raise ValueError('boundary values must span a non-zero range')
    nvals = [(v - bvals[0]) / span for v in bvals]  # normalized to [0, 1]

    dcolorscale = []
    for lower, upper, color in zip(nvals, nvals[1:], colors):
        dcolorscale.extend([[lower, color], [upper, color]])
    return dcolorscale


def _build_heatmap_figure(z, sent1, sent2, Cost, colorscale, value_label):
    """Create the heatmap figure shared by both public plotting functions.

    Args:
        z: Matrix of values shown as cell colors.
        sent1: Tokens for the x axis.
        sent2: Tokens for the y axis.
        Cost: Matrix passed as ``customdata`` and shown as "cost" on hover.
        colorscale: Plotly colorscale (name or explicit list).
        value_label: Label printed before the ``z`` value in the hover box.

    Returns:
        The configured ``plotly.graph_objects.Figure``.
    """
    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)

    # Plotly only breaks hover-text lines on <br>; a plain newline is not
    # rendered as a line break.
    hovertemplate = (
        'x: %{x}<br>'
        + 'y: %{y}<br>'
        + value_label + ': %{z:.3f}<br>'
        + 'cost: %{customdata:.3f} '
    )

    fig = go.Figure()
    fig.add_trace(go.Heatmap(
        z=z,
        customdata=Cost,
        x=sent1,
        y=sent2,
        xgap=2,
        ygap=2,
        colorscale=colorscale,
        colorbar=dict(
            tick0=0,
            dtick=0.125,
            outlinewidth=0,
        ),
        hovertemplate=hovertemplate,
        name='',
    ))
    fig.update_layout(
        #xaxis=dict(scaleanchor='y'),
        yaxis=dict(autorange='reversed'),
        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(
            size=16,
        ),
        hoverlabel=dict(
            bgcolor="#555",
            font_color="white",
            font_size=14,
            font_family="Open Sans",
        ),
    )
    fig.update_xaxes(
        tickangle=-45,
    )
    return fig


def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
    """Plot an alignment matrix as a heatmap with eight discrete color bins.

    Entries less than or equal to ``threshhold`` are replaced by 0 before
    plotting. The color bins are equal-width steps of 0.125 over [0, 1].

    Args:
        align_matrix: Array of alignment values; it is compared elementwise
            against ``threshhold``.
        sent1: Tokens for the x axis.
        sent2: Tokens for the y axis.
        threshhold: Cut-off at or below which entries are zeroed.
        Cost: Matrix shown as "cost" in the hover box.

    Returns:
        The ``plotly.graph_objects.Figure`` holding the heatmap.
    """
    align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)

    _colors = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5',
               '#33b7df', '#1B88A6', '#105264', '#092E39']
    _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    colorscale = discrete_colorscale(_ticks, _colors)

    return _build_heatmap_figure(
        align_matrix, sent1, sent2, Cost, colorscale, 'P'
    )


def plot_similarity_matrix_heatmap_plotly(similarity_matrix, sent1, sent2, Cost):
    """Plot a similarity matrix as a heatmap using the "Reds" colorscale.

    Args:
        similarity_matrix: Matrix of values, labelled "cosine" on hover.
        sent1: Tokens for the x axis.
        sent2: Tokens for the y axis.
        Cost: Matrix shown as "cost" in the hover box.

    Returns:
        The ``plotly.graph_objects.Figure`` holding the heatmap.
    """
    return _build_heatmap_figure(
        similarity_matrix, sent1, sent2, Cost, "Reds", 'cosine'
    )