zwt123home123
committed on
Update modeling_internlm2.py
Browse files- modeling_internlm2.py +3 -3
modeling_internlm2.py
CHANGED
@@ -462,9 +462,9 @@ class InternLM2Attention(nn.Module):
|
|
462 |
'''
|
463 |
# compute headmask based on attn weights ratio.
|
464 |
if attn_weights.shape[2]>1:
|
465 |
-
res_v = torch.sum( attn_weights[0,:,:,
|
466 |
-
res_t = torch.sum( attn_weights[0,:,:,image_token_num:],dim=[1,2])
|
467 |
-
res_s = torch.sum( attn_weights[0,:,:,:
|
468 |
res = res_v/(res_t+res_s)
|
469 |
torch.save(res, 'headcut_mask_8B/'+str(idx)+'.pth')
|
470 |
if idx ==31:
|
|
|
462 |
'''
|
463 |
# compute headmask based on attn weights ratio.
|
464 |
if attn_weights.shape[2]>1:
|
465 |
+
res_v = torch.sum( attn_weights[0,:,:,self.offset:self.offset+image_token_num],dim=[1,2])
|
466 |
+
res_t = torch.sum( attn_weights[0,:,:,self.offset+image_token_num:],dim=[1,2])
|
467 |
+
res_s = torch.sum( attn_weights[0,:,:,:self.offset],dim=[1,2])
|
468 |
res = res_v/(res_t+res_s)
|
469 |
torch.save(res, 'headcut_mask_8B/'+str(idx)+'.pth')
|
470 |
if idx ==31:
|