Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,10 +6,12 @@ import matplotlib.pyplot as plt
|
|
6 |
import seaborn as sns
|
7 |
from enum import Enum
|
8 |
|
9 |
-
class VisType(Enum):
|
10 |
-
ALL = 'ALL'
|
11 |
|
|
|
|
|
|
|
12 |
|
|
|
13 |
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
|
14 |
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
|
15 |
|
@@ -21,10 +23,13 @@ def analyze_sentence(index, vis_type):
|
|
21 |
attn_map_shape = row['attention_maps_shape'][1:]
|
22 |
seq_len = attn_map_shape[1]
|
23 |
attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
|
24 |
-
fig = plt.figure(figsize=(0.5 + 0.
|
25 |
attn_maps = attn_maps[:, 1:, 1:]
|
26 |
-
if vis_type == VisType.
|
27 |
plot_data = attn_maps.sum(0)
|
|
|
|
|
|
|
28 |
else:
|
29 |
print(vis_type)
|
30 |
0/0
|
@@ -34,7 +39,8 @@ def analyze_sentence(index, vis_type):
|
|
34 |
plt.ylabel('TARGET')
|
35 |
plt.xlabel('SOURCE')
|
36 |
plt.grid()
|
37 |
-
metrics = {
|
|
|
38 |
return fig, metrics
|
39 |
|
40 |
demo = gr.Blocks()
|
@@ -43,8 +49,8 @@ with demo:
|
|
43 |
sentence_dropdown = gr.Dropdown(label="Sentence",
|
44 |
choices=[x.split('</s> ')[1] for x in dataset['text']],
|
45 |
value=0, min_width=500, type='index')
|
46 |
-
vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType],
|
47 |
-
min_width=150, value=VisType.
|
48 |
btn = gr.Button("Run", min_width=30)
|
49 |
output = gr.Plot(label="Plot", container=True)
|
50 |
metrics = gr.Label("Metrics")
|
|
|
6 |
import seaborn as sns
|
7 |
from enum import Enum
|
8 |
|
|
|
|
|
9 |
|
10 |
+
class VisType(Enum):
|
11 |
+
SUM = 'Sum over Layers'
|
12 |
+
|
13 |
|
14 |
+
num_layers = 24
|
15 |
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
|
16 |
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
|
17 |
|
|
|
23 |
attn_map_shape = row['attention_maps_shape'][1:]
|
24 |
seq_len = attn_map_shape[1]
|
25 |
attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
|
26 |
+
fig = plt.figure(figsize=(0.5 + 0.4 * len(tokenized), 0.35 * len(tokenized)))
|
27 |
attn_maps = attn_maps[:, 1:, 1:]
|
28 |
+
if vis_type == VisType.SUM.value:
|
29 |
plot_data = attn_maps.sum(0)
|
30 |
+
elif vis_type.startswith('Layer #'):
|
31 |
+
layer_to_inspect = int(vis_type.split('#')[1])
|
32 |
+
plot_data = attn_maps[layer_to_inspect]
|
33 |
else:
|
34 |
print(vis_type)
|
35 |
0/0
|
|
|
39 |
plt.ylabel('TARGET')
|
40 |
plt.xlabel('SOURCE')
|
41 |
plt.grid()
|
42 |
+
metrics = {'Metrics': 0}
|
43 |
+
metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
|
44 |
return fig, metrics
|
45 |
|
46 |
demo = gr.Blocks()
|
|
|
49 |
sentence_dropdown = gr.Dropdown(label="Sentence",
|
50 |
choices=[x.split('</s> ')[1] for x in dataset['text']],
|
51 |
value=0, min_width=500, type='index')
|
52 |
+
vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType] + ['Layer #' + i for i in range(num_layers)],
|
53 |
+
min_width=150, value=VisType.SUM, type='value')
|
54 |
btn = gr.Button("Run", min_width=30)
|
55 |
output = gr.Plot(label="Plot", container=True)
|
56 |
metrics = gr.Label("Metrics")
|