Spaces:
Runtime error
Runtime error
taka-yamakoshi
commited on
Commit
·
28c20d6
1
Parent(s):
a0471c4
slide
Browse files
app.py
CHANGED
@@ -222,102 +222,101 @@ if __name__=='__main__':
|
|
222 |
st.experimental_rerun()
|
223 |
|
224 |
if st.session_state['page_status']=='analysis':
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
st.pyplot(fig)
|
|
|
222 |
st.experimental_rerun()
|
223 |
|
224 |
if st.session_state['page_status']=='analysis':
|
225 |
+
sent_1 = st.session_state['sent_1']
|
226 |
+
sent_2 = st.session_state['sent_2']
|
227 |
+
#show_annotated_sentence(st.session_state['decoded_sent_1'],
|
228 |
+
# option_locs=st.session_state['option_locs_1'],
|
229 |
+
# mask_locs=st.session_state['mask_locs_1'])
|
230 |
+
#show_annotated_sentence(st.session_state['decoded_sent_2'],
|
231 |
+
# option_locs=st.session_state['option_locs_2'],
|
232 |
+
# mask_locs=st.session_state['mask_locs_2'])
|
233 |
+
|
234 |
+
option_1_locs, option_2_locs = {}, {}
|
235 |
+
pron_locs = {}
|
236 |
+
input_ids_dict = {}
|
237 |
+
masked_ids_option_1 = {}
|
238 |
+
masked_ids_option_2 = {}
|
239 |
+
for sent_id in [1,2]:
|
240 |
+
option_1_locs[f'sent_{sent_id}'], option_2_locs[f'sent_{sent_id}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
|
241 |
+
pron_locs[f'sent_{sent_id}'] = st.session_state[f'mask_locs_{sent_id}']
|
242 |
+
input_ids_dict[f'sent_{sent_id}'] = tokenizer(st.session_state[f'sent_{sent_id}']).input_ids
|
243 |
+
|
244 |
+
masked_ids_option_1[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
|
245 |
+
pron_locs[f'sent_{sent_id}'],
|
246 |
+
option_1_locs[f'sent_{sent_id}'],mask_id)
|
247 |
+
masked_ids_option_2[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
|
248 |
+
pron_locs[f'sent_{sent_id}'],
|
249 |
+
option_2_locs[f'sent_{sent_id}'],mask_id)
|
250 |
+
|
251 |
+
#st.write(option_1_locs)
|
252 |
+
#st.write(option_2_locs)
|
253 |
+
#st.write(pron_locs)
|
254 |
+
#for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
|
255 |
+
# st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
|
256 |
+
|
257 |
+
option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
|
258 |
+
option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
|
259 |
+
option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
|
260 |
+
option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])+1]
|
261 |
+
assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
|
262 |
+
option_1_tokens = option_1_tokens_1
|
263 |
+
option_2_tokens = option_2_tokens_1
|
264 |
+
|
265 |
+
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
266 |
+
probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
267 |
+
df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
|
268 |
+
[probs_original[0,1][0],probs_original[1,1][0]]],
|
269 |
+
columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
|
270 |
+
index=['Sentence 1','Sentence 2'])
|
271 |
+
cols = st.columns(3)
|
272 |
+
with cols[1]:
|
273 |
+
show_instruction('Probability of predicting each option in each sentence',fontsize=12)
|
274 |
+
st.dataframe(df.style.highlight_max(axis=1),use_container_width=True)
|
275 |
+
|
276 |
+
compare_1 = np.array(masked_ids_option_1['sent_1'])!=np.array(masked_ids_option_1['sent_2'])
|
277 |
+
compare_2 = np.array(masked_ids_option_2['sent_1'])!=np.array(masked_ids_option_2['sent_2'])
|
278 |
+
assert np.all(compare_1.astype(int)==compare_2.astype(int))
|
279 |
+
context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
|
280 |
+
|
281 |
+
multihead = True
|
282 |
+
assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
|
283 |
+
assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
|
284 |
+
assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
|
285 |
+
token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
|
286 |
+
#st.write(token_id_list)
|
287 |
+
|
288 |
+
effect_array = []
|
289 |
+
for token_id in token_id_list:
|
290 |
+
token_id += 1
|
291 |
+
effect_list = []
|
292 |
+
for layer_id in range(num_layers):
|
293 |
+
interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
294 |
+
if multihead:
|
295 |
+
probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
296 |
+
else:
|
297 |
+
probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
298 |
+
effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
|
299 |
+
effect_list.append(effect)
|
300 |
+
effect_array.append(effect_list)
|
301 |
+
effect_array = np.transpose(np.array(effect_array),(1,0,2))
|
302 |
+
|
303 |
+
cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
|
304 |
+
token_id = 0
|
305 |
+
for col_id,col in enumerate(cols):
|
306 |
+
with col:
|
307 |
+
st.write(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
|
308 |
+
if col_id in token_id_list:
|
309 |
+
interv_id = token_id_list.index(col_id)
|
310 |
+
fig,ax = plt.subplots()
|
311 |
+
ax.set_box_aspect(num_layers)
|
312 |
+
ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
|
313 |
+
vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
|
314 |
+
ax.set_xticks([])
|
315 |
+
ax.set_xticklabels([])
|
316 |
+
ax.set_yticks([])
|
317 |
+
ax.set_yticklabels([])
|
318 |
+
ax.spines['top'].set_visible(False)
|
319 |
+
ax.spines['bottom'].set_visible(False)
|
320 |
+
ax.spines['right'].set_visible(False)
|
321 |
+
ax.spines['left'].set_visible(False)
|
322 |
+
st.pyplot(fig)
|
|