MrVicente commited on
Commit
4ae80b2
·
1 Parent(s): d74add2

added attention vizualization and qa model

Browse files
Files changed (3) hide show
  1. app.py +31 -8
  2. attention_viz.py +227 -0
  3. custom_bart/bart_attention.py +1 -1
app.py CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
 
4
  from inference import RelationsInference
5
- from utils import KGType,Model_Type
 
6
 
7
  #prep
8
  import nltk
@@ -16,28 +17,50 @@ examples = [["What's the meaning of life?", "eli5", "constraint"],
16
  ["boat, water, bird", "commongen", "constraint"],
17
  ["What flows under a bridge?", "commonsense_qa", "constraint"]]
18
 
19
- bart = RelationsInference(
20
  model_path='MrVicente/commonsense_bart_commongen',
21
  kg_type=KGType.CONCEPTNET,
22
  model_type=Model_Type.RELATIONS,
23
  max_length=32
24
  )
25
 
 
 
 
 
 
 
 
26
  #############################
27
  # Helper
28
  #############################
29
 
30
  def infer_bart(context, task_type, decoding_type_str):
31
- response, encoder_attentions, model_input = bart.generate_based_on_context(context, use_kg=False)
 
 
 
 
 
 
 
 
32
  return response[0]
33
 
34
 
35
- def plot_attention(layer, head):
36
  fig = plt.figure()
37
- plt.plot([1, 2, 3], [2, 4, 6])
38
- plt.title("Things")
39
- plt.ylabel("Cases")
40
- plt.xlabel("Days since Day 0")
 
 
 
 
 
 
 
41
  return fig
42
 
43
 
 
2
  import matplotlib.pyplot as plt
3
 
4
  from inference import RelationsInference
5
+ from attention_viz import AttentionVisualizer
6
+ from utils import KGType, Model_Type, Data_Type
7
 
8
  #prep
9
  import nltk
 
17
  ["boat, water, bird", "commongen", "constraint"],
18
  ["What flows under a bridge?", "commonsense_qa", "constraint"]]
19
 
20
+ commongen_bart = RelationsInference(
21
  model_path='MrVicente/commonsense_bart_commongen',
22
  kg_type=KGType.CONCEPTNET,
23
  model_type=Model_Type.RELATIONS,
24
  max_length=32
25
  )
26
 
27
+ qa_bart = RelationsInference(
28
+ model_path='MrVicente/commonsense_bart_absqa',
29
+ kg_type=KGType.CONCEPTNET,
30
+ model_type=Model_Type.RELATIONS,
31
+ max_length=128
32
+ )
33
+ att_viz = AttentionVisualizer(device='cpu')
34
  #############################
35
  # Helper
36
  #############################
37
 
38
  def infer_bart(context, task_type, decoding_type_str):
39
+ if Data_Type(task_type) == Data_Type.COMMONGEN:
40
+ if decoding_type_str =='default':
41
+ response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
42
+ else:
43
+ response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True)
44
+ elif Data_Type(task_type) == Data_Type.ELI5:
45
+ response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
46
+ else:
47
+ raise NotImplementedError()
48
  return response[0]
49
 
50
 
51
+ def plot_attention(context, task_type, layer, head):
52
  fig = plt.figure()
53
+ if Data_Type(task_type) == Data_Type.COMMONGEN:
54
+ model = commongen_bart
55
+ elif Data_Type(task_type) == Data_Type.ELI5:
56
+ model = qa_bart
57
+ else:
58
+ raise NotImplementedError()
59
+ response, examples, relations = model.prepare_context_for_visualization(context)
60
+ att_viz.plot_attn_lines_concepts_ids('Input text importance visualized',
61
+ examples,
62
+ layer, head,
63
+ relations)
64
  return fig
65
 
66
 
attention_viz.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+
7
+ # Remote modules
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import torch
11
+
12
+ # Local modules
13
+
14
+ #############################
15
+ # Constants
16
+ #############################
17
+
18
+ class AttentionVisualizer:
19
+ def __init__(self, device):
20
+ self.device = device
21
+
22
+ def visualize_token2token_scores(self, all_tokens,
23
+ scores_mat,
24
+ useful_indeces,
25
+ x_label_name='Head',
26
+ apply_normalization=True):
27
+ fig = plt.figure(figsize=(20, 20))
28
+
29
+ all_tokens = np.array(all_tokens)[useful_indeces]
30
+ for idx, scores in enumerate(scores_mat):
31
+ if apply_normalization:
32
+ scores = torch.from_numpy(scores)
33
+ shape = scores.shape
34
+ scores = scores.reshape((shape[0],shape[1], 1))
35
+ scores = torch.linalg.norm(scores, dim=2)
36
+ scores_np = np.array(scores)
37
+ scores_np = scores_np[useful_indeces, :]
38
+ scores_np = scores_np[:, useful_indeces]
39
+ ax = fig.add_subplot(4, 4, idx + 1)
40
+ # append the attention weights
41
+ im = ax.imshow(scores_np, cmap='viridis')
42
+
43
+ fontdict = {'fontsize': 10}
44
+
45
+ ax.set_xticks(range(len(all_tokens)))
46
+ ax.set_yticks(range(len(all_tokens)))
47
+
48
+ ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
49
+ ax.set_yticklabels(all_tokens, fontdict=fontdict)
50
+ ax.set_xlabel('{} {}'.format(x_label_name, idx + 1))
51
+
52
+ fig.colorbar(im, fraction=0.046, pad=0.04)
53
+ plt.tight_layout()
54
+ plt.show()
55
+
56
+ def visualize_matrix(self,
57
+ scores_mat,
58
+ label_name='heads_layers'):
59
+ _fig = plt.figure(figsize=(20, 20))
60
+ scores_np = np.array(scores_mat)
61
+ fig, ax = plt.subplots()
62
+ im = ax.imshow(scores_np, cmap='viridis')
63
+
64
+ fontdict = {'fontsize': 10}
65
+
66
+ ax.set_xticks(range(len(scores_mat[0])))
67
+ ax.set_yticks(range(len(scores_mat)))
68
+
69
+ x_labels = [f'head-{i}' for i in range(1, len(scores_mat[0])+1)]
70
+ y_labels = [f'layer-{i}' for i in range(1, len(scores_mat) + 1)]
71
+
72
+ ax.set_xticklabels(x_labels, fontdict=fontdict, rotation=90)
73
+ ax.set_yticklabels(y_labels, fontdict=fontdict)
74
+ ax.set_xlabel('{}'.format(label_name))
75
+
76
+ fig.colorbar(im, fraction=0.046, pad=0.04)
77
+ plt.tight_layout()
78
+ #plt.show()
79
+ plt.savefig(f'figs/{label_name}.png', dpi=fig.dpi)
80
+
81
+ def visualize_token2head_scores(self, all_tokens, scores_mat):
82
+ fig = plt.figure(figsize=(30, 50))
83
+ for idx, scores in enumerate(scores_mat):
84
+ scores_np = np.array(scores)
85
+ ax = fig.add_subplot(6, 3, idx + 1)
86
+ # append the attention weights
87
+ im = ax.matshow(scores_np, cmap='viridis')
88
+
89
+ fontdict = {'fontsize': 20}
90
+
91
+ ax.set_xticks(range(len(all_tokens)))
92
+ ax.set_yticks(range(len(scores)))
93
+
94
+ ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
95
+ ax.set_yticklabels(range(len(scores[0])), fontdict=fontdict)
96
+ ax.set_xlabel('Layer {}'.format(idx + 1))
97
+
98
+ fig.colorbar(im, fraction=0.046, pad=0.04)
99
+ plt.tight_layout()
100
+ plt.show()
101
+
102
+ def plot_attn_lines(self, data, heads):
103
+ """Plots attention maps for the given example and attention heads."""
104
+ width = 3
105
+ example_sep = 3
106
+ word_height = 1
107
+ pad = 0.1
108
+
109
+ for ei, (layer, head) in enumerate(heads):
110
+ yoffset = 1
111
+ xoffset = ei * width * example_sep
112
+
113
+ attn = data["attns"][layer][head]
114
+ attn = np.array(attn)
115
+ attn /= attn.sum(axis=-1, keepdims=True)
116
+ words = data["tokens"]
117
+ words[0] = "..."
118
+ n_words = len(words)
119
+
120
+ for position, word in enumerate(words):
121
+ plt.text(xoffset + 0, yoffset - position * word_height, word,
122
+ ha="right", va="center")
123
+ plt.text(xoffset + width, yoffset - position * word_height, word,
124
+ ha="left", va="center")
125
+ for i in range(1, n_words):
126
+ for j in range(1, n_words):
127
+ plt.plot([xoffset + pad, xoffset + width - pad],
128
+ [yoffset - word_height * i, yoffset - word_height * j],
129
+ color="blue", linewidth=1, alpha=attn[i, j])
130
+
131
+ def plot_attn_lines_concepts(self, title, examples, layer, head, color_words,
132
+ color_from=True, width=3, example_sep=3,
133
+ word_height=1, pad=0.1, hide_sep=False):
134
+ # examples -> {'words': tokens, 'attentions': [layer][head]}
135
+ plt.figure(figsize=(4, 4))
136
+ for i, example in enumerate(examples):
137
+ yoffset = 0
138
+ if i == 0:
139
+ yoffset += (len(examples[0]["words"]) -
140
+ len(examples[1]["words"])) * word_height / 2
141
+ xoffset = i * width * example_sep
142
+ attn = example["attentions"][layer][head]
143
+ if hide_sep:
144
+ attn = np.array(attn)
145
+ attn[:, 0] = 0
146
+ attn[:, -1] = 0
147
+ attn /= attn.sum(axis=-1, keepdims=True)
148
+
149
+ words = example["words"]
150
+ n_words = len(words)
151
+ for position, word in enumerate(words):
152
+ for x, from_word in [(xoffset, True), (xoffset + width, False)]:
153
+ color = "k"
154
+ if from_word == color_from and word in color_words:
155
+ color = "#cc0000"
156
+ plt.text(x, yoffset - (position * word_height), word,
157
+ ha="right" if from_word else "left", va="center",
158
+ color=color)
159
+
160
+ for i in range(n_words):
161
+ for j in range(n_words):
162
+ color = "b"
163
+ if words[i if color_from else j] in color_words:
164
+ color = "r"
165
+ print(attn[i, j])
166
+ plt.plot([xoffset + pad, xoffset + width - pad],
167
+ [yoffset - word_height * i, yoffset - word_height * j],
168
+ color=color, linewidth=1, alpha=attn[i, j])
169
+ plt.axis("off")
170
+ plt.title(title)
171
+ plt.show()
172
+
173
+ def plot_attn_lines_concepts_ids(title, examples, layer, head,
174
+ relations_total, width=3, example_sep=3,
175
+ word_height=1, pad=0.1, hide_sep=False):
176
+ # examples -> {'words': tokens, 'attentions': [layer][head]}
177
+ plt.clf()
178
+ plt.figure(figsize=(10, 5))
179
+ # print('relations_total:', relations_total)
180
+ # print(examples[0])
181
+ for idx, example in enumerate(examples):
182
+ yoffset = 0
183
+ if idx == 0:
184
+ yoffset += (len(examples[0]["words"]) -
185
+ len(examples[0]["words"])) * word_height / 2
186
+ xoffset = idx * width * example_sep
187
+ attn = example["attentions"][layer][head]
188
+ if hide_sep:
189
+ attn = np.array(attn)
190
+ attn[:, 0] = 0
191
+ attn[:, -1] = 0
192
+ attn /= attn.sum(axis=-1, keepdims=True)
193
+
194
+ words = example["words"]
195
+ n_words = len(words)
196
+ example_rel = relations_total[idx]
197
+ for position, word in enumerate(words):
198
+ for x, from_word in [(xoffset, True), (xoffset + width, False)]:
199
+ color = "k"
200
+ for y_idx, y in enumerate(words):
201
+ if from_word and example_rel[position, y_idx] > 0:
202
+ # print('outgoing', position, y_idx)
203
+ color = "r"
204
+ if not from_word and example_rel[y_idx, position] > 0:
205
+ # print('coming', position, y_idx)
206
+ color = "g"
207
+ # if from_word == color_from and word in color_words:
208
+ # color = "#cc0000"
209
+ plt.text(x, yoffset - (position * word_height), word,
210
+ ha="right" if from_word else "left", va="center",
211
+ color=color)
212
+
213
+ for i in range(n_words):
214
+ for j in range(n_words):
215
+ color = "k"
216
+ # print(i,j, example_rel[i,j])
217
+ if example_rel[i, j].item() > 0 and i <= j:
218
+ color = "r"
219
+ if example_rel[i, j].item() > 0 and i >= j:
220
+ color = "g"
221
+ plt.plot([xoffset + pad, xoffset + width - pad],
222
+ [yoffset - word_height * i, yoffset - word_height * j],
223
+ color=color, linewidth=1, alpha=attn[i, j])
224
+ # color=color, linewidth=1, alpha=min(attn[i, j]*10,1))
225
+ plt.axis("off")
226
+ plt.title(title)
227
+ plt.show()
custom_bart/bart_attention.py CHANGED
@@ -94,7 +94,7 @@ class BartCustomAttention(nn.Module):
94
  # TODO
95
  print('oh no')
96
  relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
97
- print(relation_inputs.shape, ' | ', (bsz, tgt_len, tgt_len))
98
  assert relation_inputs.shape == (bsz, tgt_len, tgt_len)
99
 
100
  # (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds)
 
94
  # TODO
95
  print('oh no')
96
  relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
97
+ #print(relation_inputs.shape, ' | ', (bsz, tgt_len, tgt_len))
98
  assert relation_inputs.shape == (bsz, tgt_len, tgt_len)
99
 
100
  # (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds)