dar-tau commited on
Commit
c5fa8a7
·
verified ·
1 Parent(s): 9fa205f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -6,10 +6,12 @@ import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  from enum import Enum
8
 
9
- class VisType(Enum):
10
- ALL = 'ALL'
11
 
 
 
 
12
 
 
13
  dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
14
  tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
15
 
@@ -21,10 +23,13 @@ def analyze_sentence(index, vis_type):
21
  attn_map_shape = row['attention_maps_shape'][1:]
22
  seq_len = attn_map_shape[1]
23
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
24
- fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
25
  attn_maps = attn_maps[:, 1:, 1:]
26
- if vis_type == VisType.ALL.value:
27
  plot_data = attn_maps.sum(0)
 
 
 
28
  else:
29
  print(vis_type)
30
  0/0
@@ -34,7 +39,8 @@ def analyze_sentence(index, vis_type):
34
  plt.ylabel('TARGET')
35
  plt.xlabel('SOURCE')
36
  plt.grid()
37
- metrics = {k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']}
 
38
  return fig, metrics
39
 
40
  demo = gr.Blocks()
@@ -43,8 +49,8 @@ with demo:
43
  sentence_dropdown = gr.Dropdown(label="Sentence",
44
  choices=[x.split('</s> ')[1] for x in dataset['text']],
45
  value=0, min_width=500, type='index')
46
- vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType],
47
- min_width=150, value=VisType.ALL, type='value')
48
  btn = gr.Button("Run", min_width=30)
49
  output = gr.Plot(label="Plot", container=True)
50
  metrics = gr.Label("Metrics")
 
6
  import seaborn as sns
7
  from enum import Enum
8
 
 
 
9
 
10
+ class VisType(Enum):
11
+ SUM = 'Sum over Layers'
12
+
13
 
14
+ num_layers = 24
15
  dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
16
  tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
17
 
 
23
  attn_map_shape = row['attention_maps_shape'][1:]
24
  seq_len = attn_map_shape[1]
25
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
26
+ fig = plt.figure(figsize=(0.5 + 0.4 * len(tokenized), 0.35 * len(tokenized)))
27
  attn_maps = attn_maps[:, 1:, 1:]
28
+ if vis_type == VisType.SUM.value:
29
  plot_data = attn_maps.sum(0)
30
+ elif vis_type.startswith('Layer #'):
31
+ layer_to_inspect = int(vis_type.split('#')[1])
32
+ plot_data = attn_maps[layer_to_inspect]
33
  else:
34
  print(vis_type)
35
  0/0
 
39
  plt.ylabel('TARGET')
40
  plt.xlabel('SOURCE')
41
  plt.grid()
42
+ metrics = {'Metrics': 0}
43
+ metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
44
  return fig, metrics
45
 
46
  demo = gr.Blocks()
 
49
  sentence_dropdown = gr.Dropdown(label="Sentence",
50
  choices=[x.split('</s> ')[1] for x in dataset['text']],
51
  value=0, min_width=500, type='index')
52
+ vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType] + ['Layer #' + i for i in range(num_layers)],
53
+ min_width=150, value=VisType.SUM, type='value')
54
  btn = gr.Button("Run", min_width=30)
55
  output = gr.Plot(label="Plot", container=True)
56
  metrics = gr.Label("Metrics")