import numpy as np
import plotly.graph_objects as go
def _debug_non_unique_axis_values(sent1: list[str], sent2: list[str]):
"""
solution:
using zero-width-space
cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
"""
sent1 = [word + i*'\u200b' for i, word in enumerate(sent1)]
sent2 = [word + i*'\u200b' for i, word in enumerate(sent2)]
return sent1, sent2
def discrete_colorscale(bvals, colors):
    """Build a plotly colorscale that is piecewise constant over given intervals.

    ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780

    Args:
        bvals: Boundary values delimiting the intervals of interest. They are
            sorted before use and must not all be equal.
        colors: One rgb or hex color code per interval; ``colors[k]`` is used
            for values in ``[bvals[k], bvals[k+1]]`` (after sorting),
            ``0 <= k < len(bvals) - 1``.

    Returns:
        The plotly discrete colorscale: a list of ``[position, color]`` pairs
        whose positions are the boundaries rescaled to ``[0, 1]``. Each color
        is listed at both ends of its interval so plotly does not blend it
        with its neighbours.

    Raises:
        ValueError: if ``len(bvals) != len(colors) + 1``, or if every boundary
            value is equal (the range cannot be normalized).
    """
    if len(bvals) != len(colors) + 1:
        raise ValueError('len(boundary values) should be equal to len(colors)+1')
    bvals = sorted(bvals)
    lo = bvals[0]
    span = bvals[-1] - lo
    if span == 0:
        # Without this check the normalization below divides by zero and
        # surfaces as an opaque ZeroDivisionError.
        raise ValueError('boundary values must not all be equal')
    nvals = [(v - lo) / span for v in bvals]  # boundaries normalized to [0, 1]
    dcolorscale = []  # discrete colorscale
    # Pair each color with its interval's normalized start and end.
    for color, start, end in zip(colors, nvals, nvals[1:]):
        dcolorscale.extend([[start, color], [end, color]])
    return dcolorscale
def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
    """Plot an alignment matrix as a heatmap with an 8-step discrete colorscale.

    Args:
        align_matrix: 2-D array-like of alignment values. Following plotly's
            ``Heatmap`` convention, rows are laid out along ``sent2`` (y axis)
            and columns along ``sent1`` (x axis). Entries ``<= threshhold``
            are replaced by 0 before plotting.
        sent1: Tokens labelling the x axis.
        sent2: Tokens labelling the y axis (first token drawn at the top).
        threshhold: Cut-off; cells at or below it are zeroed.
        Cost: Array-like matching ``align_matrix``; passed as ``customdata``
            and shown as ``cost`` in the hover label.

    Returns:
        A ``plotly.graph_objects.Figure`` with a single heatmap trace.
    """
    # asarray lets plain nested lists through; ``list <= float`` would raise.
    align_matrix = np.asarray(align_matrix)
    align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
    # Repeated words would otherwise collapse into one row/column.
    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
    _colors = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
    _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    colorscale = discrete_colorscale(_ticks, _colors)
    # NOTE(review): zmin/zmax are not set, so plotly stretches this colorscale
    # over the data's own min..max. The bins match _ticks only when the data
    # spans [0, 1]. Confirm whether zmin=0, zmax=1 is intended.
    fig = go.Figure()
    fig.add_trace(go.Heatmap(
        z=align_matrix,
        customdata=Cost,
        x=sent1,
        y=sent2,
        xgap=2,
        ygap=2,
        colorscale=colorscale,
        colorbar=dict(
            tick0=0,
            dtick=0.125,
            outlinewidth=0
        ),
        # Plotly hover labels break lines with the HTML <br> tag.
        hovertemplate=(
            'x: %{x}<br>'
            'y: %{y}<br>'
            'P: %{z:.3f}<br>'
            'cost: %{customdata:.3f} '
        ),
        # Empty trace name keeps "trace 0" out of the hover label.
        name=''
    ))
    fig.update_layout(
        # Uncomment to force square cells:
        # xaxis=dict(scaleanchor='y'),
        yaxis=dict(autorange='reversed'),  # first sent2 token at the top
        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(
            size=16,
        ),
        hoverlabel=dict(
            bgcolor="#555",
            font_color="white",
            font_size=14,
            font_family="Open Sans"
        )
    )
    fig.update_xaxes(
        tickangle=-45,
    )
    return fig
def plot_similarity_matrix_heatmap_plotly(similarity_matrix, sent1, sent2, Cost):
    """Plot a token similarity matrix as a heatmap using plotly's "Reds" scale.

    Args:
        similarity_matrix: 2-D array-like of similarity values (shown as
            ``cosine`` in the hover label). Following plotly's ``Heatmap``
            convention, rows are laid out along ``sent2`` (y axis) and columns
            along ``sent1`` (x axis).
        sent1: Tokens labelling the x axis.
        sent2: Tokens labelling the y axis (first token drawn at the top).
        Cost: Array-like matching ``similarity_matrix``; passed as
            ``customdata`` and shown as ``cost`` in the hover label.

    Returns:
        A ``plotly.graph_objects.Figure`` with a single heatmap trace.
    """
    # Repeated words would otherwise collapse into one row/column.
    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
    fig = go.Figure()
    fig.add_trace(go.Heatmap(
        z=similarity_matrix,
        customdata=Cost,
        x=sent1,
        y=sent2,
        xgap=2,
        ygap=2,
        colorscale="Reds",
        colorbar=dict(
            tick0=0,
            dtick=0.125,
            outlinewidth=0
        ),
        # Plotly hover labels break lines with the HTML <br> tag.
        hovertemplate=(
            'x: %{x}<br>'
            'y: %{y}<br>'
            'cosine: %{z:.3f}<br>'
            'cost: %{customdata:.3f} '
        ),
        # Empty trace name keeps "trace 0" out of the hover label.
        name=''
    ))
    fig.update_layout(
        # Uncomment to force square cells:
        # xaxis=dict(scaleanchor='y'),
        yaxis=dict(autorange='reversed'),  # first sent2 token at the top
        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(
            size=16,
        ),
        hoverlabel=dict(
            bgcolor="#555",
            font_color="white",
            font_size=14,
            font_family="Open Sans"
        )
    )
    fig.update_xaxes(
        tickangle=-45,
    )
    return fig