fixed gradio plot issue
Browse files- app.py +1 -2
- attention_viz.py +3 -2
app.py
CHANGED
@@ -49,7 +49,6 @@ def infer_bart(context, task_type, decoding_type_str):
|
|
49 |
|
50 |
|
51 |
def plot_attention(context, task_type, layer, head):
|
52 |
-
fig = plt.figure()
|
53 |
if Data_Type(task_type) == Data_Type.COMMONGEN:
|
54 |
model = commongen_bart
|
55 |
elif Data_Type(task_type) == Data_Type.ELI5:
|
@@ -57,7 +56,7 @@ def plot_attention(context, task_type, layer, head):
|
|
57 |
else:
|
58 |
raise NotImplementedError()
|
59 |
response, examples, relations = model.prepare_context_for_visualization(context)
|
60 |
-
att_viz.plot_attn_lines_concepts_ids('Input text importance visualized',
|
61 |
examples,
|
62 |
layer, head,
|
63 |
relations)
|
|
|
49 |
|
50 |
|
51 |
def plot_attention(context, task_type, layer, head):
|
|
|
52 |
if Data_Type(task_type) == Data_Type.COMMONGEN:
|
53 |
model = commongen_bart
|
54 |
elif Data_Type(task_type) == Data_Type.ELI5:
|
|
|
56 |
else:
|
57 |
raise NotImplementedError()
|
58 |
response, examples, relations = model.prepare_context_for_visualization(context)
|
59 |
+
fig = att_viz.plot_attn_lines_concepts_ids('Input text importance visualized',
|
60 |
examples,
|
61 |
layer, head,
|
62 |
relations)
|
attention_viz.py
CHANGED
@@ -175,7 +175,7 @@ class AttentionVisualizer:
|
|
175 |
word_height=1, pad=0.1, hide_sep=False):
|
176 |
# examples -> {'words': tokens, 'attentions': [layer][head]}
|
177 |
plt.clf()
|
178 |
-
plt.figure(figsize=(10, 5))
|
179 |
# print('relations_total:', relations_total)
|
180 |
# print(examples[0])
|
181 |
for idx, example in enumerate(examples):
|
@@ -224,4 +224,5 @@ class AttentionVisualizer:
|
|
224 |
# color=color, linewidth=1, alpha=min(attn[i, j]*10,1))
|
225 |
plt.axis("off")
|
226 |
plt.title(title)
|
227 |
-
plt.show()
|
|
|
|
175 |
word_height=1, pad=0.1, hide_sep=False):
|
176 |
# examples -> {'words': tokens, 'attentions': [layer][head]}
|
177 |
plt.clf()
|
178 |
+
fig = plt.figure(figsize=(10, 5))
|
179 |
# print('relations_total:', relations_total)
|
180 |
# print(examples[0])
|
181 |
for idx, example in enumerate(examples):
|
|
|
224 |
# color=color, linewidth=1, alpha=min(attn[i, j]*10,1))
|
225 |
plt.axis("off")
|
226 |
plt.title(title)
|
227 |
+
#plt.show()
|
228 |
+
return fig
|