Spaces:
Running
Running
import plotly.graph_objs as go | |
import textwrap | |
import re | |
from collections import defaultdict | |
from paraphraser import generate_paraphrase | |
from masking_methods import mask, mask_non_stopword | |
def _tree_levels(num_paraphrases, num_masked):
    """Return the tree level of every node, in plotting order.

    Node 0 is the original sentence (level 0). It is followed by
    ``num_paraphrases`` paraphrases (level 1) and then ``num_masked``
    masked versions (level 2).
    """
    return [0] + [1] * num_paraphrases + [2] * num_masked


def _tree_edges(levels):
    """Return ``(parent, child)`` node-index pairs for the sentence tree.

    The root (first level-0 node) is linked to every level-1 node. Only the
    first level-1 node (the paraphrase that was masked) is linked to the
    level-2 nodes.
    """
    root = levels.index(0)
    edges = [(root, i) for i, level in enumerate(levels) if level == 1]
    if 1 in levels:
        first_paraphrase = levels.index(1)
        edges.extend(
            (first_paraphrase, i) for i, level in enumerate(levels) if level == 2
        )
    return edges


def _tree_positions(levels, y_gap=4):
    """Return an ``(x, y)`` coordinate for every node, indexed like ``levels``.

    Nodes sharing a level are spaced one unit apart and centred on x = 0.
    Level ``k`` is placed at ``y = -k * y_gap``, so the root sits on top.
    """
    level_sizes = defaultdict(int)
    for level in levels:
        level_sizes[level] += 1
    # Left-most x for each level, chosen so the row is centred on 0.
    next_x = {level: -(size - 1) / 2 for level, size in level_sizes.items()}
    positions = []
    for level in levels:
        positions.append((next_x[level], -level * y_gap))
        next_x[level] += 1
    return positions


def generate_plot(original_sentence):
    """Build a Plotly tree of a sentence, its paraphrases and masked variants.

    Layout:
      * level 0 holds ``original_sentence``;
      * level 1 holds every paraphrase from ``generate_paraphrase``;
      * level 2 holds the versions returned by ``mask`` for the first
        paraphrase, after ``mask_non_stopword`` has been applied to it.

    The root is connected to every paraphrase. The first paraphrase is
    connected to every masked version.

    Args:
        original_sentence: The sentence to paraphrase and mask.

    Returns:
        A ``plotly.graph_objs.Figure`` with one line trace for the edges, one
        marker trace for the nodes, and one boxed text annotation per node.

    Raises:
        IndexError: If ``generate_paraphrase`` returns no paraphrases.
    """
    paraphrased_sentences = generate_paraphrase(original_sentence)
    # Only the first paraphrase is masked; indexing raises IndexError when
    # no paraphrase was produced.
    first_paraphrased_sentence = paraphrased_sentences[0]
    masked_sentence = mask_non_stopword(first_paraphrased_sentence)
    # Materialize so the count below works even if mask() yields lazily.
    masked_versions = list(mask(masked_sentence))

    texts = [original_sentence, *paraphrased_sentences, *masked_versions]
    levels = _tree_levels(len(paraphrased_sentences), len(masked_versions))
    edges = _tree_edges(levels)
    positions = _tree_positions(levels)

    # Plotly annotations break lines on <br>, not on "\n".
    wrapped_texts = ['<br>'.join(textwrap.wrap(text, width=30)) for text in texts]

    fig = go.Figure()

    # All edges go in a single trace; None separates the segments. They are
    # added before the nodes so the markers render on top of the lines.
    if edges:
        edge_x, edge_y = [], []
        for parent, child in edges:
            (x0, y0), (x1, y1) = positions[parent], positions[child]
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]
        fig.add_trace(go.Scatter(
            x=edge_x,
            y=edge_y,
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='none',
        ))

    fig.add_trace(go.Scatter(
        x=[x for x, _ in positions],
        y=[y for _, y in positions],
        mode='markers',
        marker=dict(size=10, color='blue'),
        hoverinfo='none',
    ))

    for (x, y), text in zip(positions, wrapped_texts):
        fig.add_annotation(
            x=x,
            y=y,
            text=text,
            showarrow=False,
            yshift=20,  # Lift the label above its marker to avoid overlap.
            align="center",
            font=dict(size=10),
            bordercolor='black',
            borderwidth=1,
            borderpad=4,
            bgcolor='white',
            width=200,
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1470,
        height=800,  # Tall enough for three levels of wrapped labels.
    )
    return fig