# UOT / plotools.py
# Author: 4kasha -- commit 37d364a ("update"), 5.35 kB
import numpy as np
import plotly.graph_objects as go
def _debug_non_unique_axis_values(sent1: list[str], sent2: list[str]):
"""
solution:
using zero-width-space
cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
"""
sent1 = [word + i * "\u200b" for i, word in enumerate(sent1)]
sent2 = [word + i * "\u200b" for i, word in enumerate(sent2)]
return sent1, sent2
def discrete_colorscale(bvals, colors):
    """Build a plotly colorscale that paints each interval in one flat colour.

    bvals - list of values bounding intervals/ranges of interest; they are
        sorted before use.
    colors - list of rgb or hex colorcodes; colors[k] is used for values in
        [bvals[k], bvals[k+1]], 0 <= k < len(bvals)-1
    returns the plotly discrete colorscale: a list of [position, colour]
        pairs with positions normalised to [0, 1], two stops per colour so
        each colour is constant across its whole interval.
    raises ValueError if len(bvals) != len(colors)+1, or if all boundary
        values are equal (the [0, 1] normalisation would divide by zero).
    ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
    """
    if len(bvals) != len(colors) + 1:
        raise ValueError("len(boundary values) should be equal to len(colors)+1")
    bvals = sorted(bvals)
    span = bvals[-1] - bvals[0]
    # A zero span would raise ZeroDivisionError for Python numbers, or silently
    # produce nan/inf positions for numpy scalars; reject it explicitly.
    # This also covers an empty ``colors`` list (a single boundary value).
    if span == 0:
        raise ValueError("boundary values must span a non-zero range")
    nvals = [(v - bvals[0]) / span for v in bvals]  # normalized values
    dcolorscale = []  # discrete colorscale
    for k, color in enumerate(colors):
        # Same colour at both ends of the interval -> a hard step, not a gradient.
        dcolorscale.extend([[nvals[k], color], [nvals[k + 1], color]])
    return dcolorscale
def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
    """Plot an alignment matrix as a heatmap with eight discrete colour bands.

    Entries of ``align_matrix`` less than or equal to ``threshhold`` are set
    to 0 before plotting. Hovering a cell shows its x/y tokens, the
    (thresholded) value labelled "P" and the matching ``Cost`` entry.

    Args:
        align_matrix: 2-D array-like of alignment weights; nested lists are
            accepted as well as arrays. Presumably shaped
            (len(sent2), len(sent1)), since plotly maps z rows to the y axis
            and columns to the x axis -- confirm against callers.
        sent1: tokens labelling the x axis.
        sent2: tokens labelling the y axis (drawn top to bottom).
        threshhold: values <= this are zeroed. The misspelt name is kept
            for callers passing it by keyword.
        Cost: per-cell values passed as ``customdata`` and shown in the hover
            label; presumably the same shape as ``align_matrix``.

    Returns:
        plotly.graph_objects.Figure containing a single heatmap trace.
    """
    # np.asarray lets callers pass nested lists, which cannot be compared
    # elementwise with ``<=``. Arrays pass through without a copy.
    align_matrix = np.asarray(align_matrix)
    align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
    # Repeated tokens would otherwise share one axis category.
    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
    # Eight flat colour bands, light to dark, with boundaries every 0.125.
    _colors = [
        "#F2F2F2",
        "#E0F4FA",
        "#BEE4F0",
        "#88CCE5",
        "#33b7df",
        "#1B88A6",
        "#105264",
        "#092E39",
    ]
    _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    colorscale = discrete_colorscale(_ticks, _colors)
    # NOTE(review): zmin/zmax are not set, so plotly presumably stretches the
    # colour bands over the data's own z range. The colorbar ticks
    # (tick0=0, dtick=0.125) then line up with band edges only when the data
    # spans exactly [0, 1]. Consider zmin=0, zmax=1 if absolute bands are
    # intended -- confirm.
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            z=align_matrix,
            customdata=Cost,
            x=sent1,
            y=sent2,
            xgap=2,
            ygap=2,
            colorscale=colorscale,
            colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
            hovertemplate="x: %{x}<br>"
            + "y: %{y}<br>"
            + "P: %{z:.3f}<br>"
            + "cost: %{customdata:.3f} ",
            # Empty trace name keeps the trace name out of the hover box.
            name="",
        )
    )
    fig.update_layout(
        # Add xaxis=dict(scaleanchor="y") here to force square cells.
        # Reversed y axis: the first token of sent2 is drawn at the top.
        yaxis=dict(autorange="reversed"),
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        plot_bgcolor="rgba(0,0,0,0)",
        font=dict(
            size=16,
        ),
        hoverlabel=dict(
            bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
        ),
    )
    fig.update_xaxes(
        tickangle=-45,
    )
    return fig
def plot_similarity_matrix_heatmap_plotly(
    similarity_matrix, sent1, sent2, Cost, colorscale="Reds", hover_z="cosine"
):
    """Plot a token-by-token similarity matrix as a continuous heatmap.

    Args:
        similarity_matrix: 2-D values used as the heatmap's z. Presumably
            shaped (len(sent2), len(sent1)) -- confirm against callers.
        sent1: tokens labelling the x axis.
        sent2: tokens labelling the y axis (drawn top to bottom).
        Cost: per-cell values passed as ``customdata`` and shown in the hover
            label.
        colorscale: plotly colorscale for the heatmap (default "Reds").
        hover_z: label printed before the z value in the hover text.

    Returns:
        plotly.graph_objects.Figure containing a single heatmap trace.
    """
    # Repeated tokens would otherwise share one axis category.
    x_labels, y_labels = _debug_non_unique_axis_values(sent1, sent2)

    hover_lines = [
        "x: %{x}",
        "y: %{y}",
        f"{hover_z}: " + "%{z:.3f}",
        "cost: %{customdata:.3f} ",
    ]
    heatmap = go.Heatmap(
        z=similarity_matrix,
        customdata=Cost,
        x=x_labels,
        y=y_labels,
        xgap=2,
        ygap=2,
        colorscale=colorscale,
        colorbar={"tick0": 0, "dtick": 0.125, "outlinewidth": 0},
        hovertemplate="<br>".join(hover_lines),
        name="",
    )
    fig = go.Figure(data=[heatmap])

    # Reversed y axis puts the first token of sent2 at the top; no margins,
    # transparent plot area.
    fig.update_layout(
        yaxis={"autorange": "reversed"},
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        plot_bgcolor="rgba(0,0,0,0)",
        font={"size": 16},
        hoverlabel={
            "bgcolor": "#555",
            "font_color": "white",
            "font_size": 14,
            "font_family": "Open Sans",
        },
    )
    # Tilt the x tick labels by -45 degrees.
    fig.update_xaxes(tickangle=-45)
    return fig
def show_assignments_plotly(P, word_embeddings, sents1, sents2, thr=0):
    """Scatter source/target token embeddings and connect matched pairs.

    Every pair (i, j) with P[i, j] > max(thr, 0) is drawn as a black line
    between source point i and target point j. The line width is
    2 * P[i, j] and the opacity is P[i, j].

    Args:
        P: 2-D array-like of weights indexed [source, target]; nested lists
            are accepted as well as arrays.
        word_embeddings: 2-D array-like of point coordinates. The first
            len(sents1) rows are the source tokens and the remaining rows are
            the target tokens. Only columns 0 and 1 are plotted (presumably
            the embeddings are already reduced to 2-D -- confirm).
        sents1: text labels for the source points.
        sents2: text labels for the target points.
        thr: entries of P less than or equal to this are not drawn.

    Returns:
        plotly.graph_objects.Figure with one line trace per drawn pair (named
        "1", "2", ... in row-major order), followed by the source marker trace
        and the target marker trace.
    """
    # np.asarray lets callers pass nested lists. The code below needs
    # elementwise comparison, ``.shape`` and 2-D indexing.
    P = np.asarray(P)
    word_embeddings = np.asarray(word_embeddings)
    P = np.where(P <= thr, 0, P)
    s1_end = len(sents1)
    a = word_embeddings[:s1_end]  # source points
    b = word_embeddings[s1_end:]  # target points
    traces = []
    sample = 0
    for i in range(a.shape[0]):
        for j in range(b.shape[0]):
            weight = P[i, j]
            if weight > 0:
                sample += 1
                # NOTE(review): plotly validates opacity to [0, 1], so this
                # assumes every weight in P is at most 1 -- confirm.
                traces.append(
                    go.Scatter(
                        x=[a[i, 0], b[j, 0]],
                        y=[a[i, 1], b[j, 1]],
                        mode="lines",
                        line=dict(color="black", width=weight * 2),
                        opacity=weight,
                        name=f"{sample}",
                    )
                )
    # Plot the source samples.
    traces.append(
        go.Scatter(
            x=a[:, 0],
            y=a[:, 1],
            mode="markers+text",
            marker=dict(color="blue", size=8, symbol="cross"),
            text=sents1,
            textposition="top center",
            name="Source samples",
        )
    )
    # Plot the target samples.
    traces.append(
        go.Scatter(
            x=b[:, 0],
            y=b[:, 1],
            mode="markers+text",
            marker=dict(color="red", size=8, symbol="x"),
            text=sents2,
            textposition="bottom center",
            name="Target samples",
        )
    )
    layout = go.Layout(
        showlegend=True,
        margin=dict(l=0, r=0, t=10, b=0),
    )
    fig = go.Figure(data=traces, layout=layout)
    return fig