zwt123home123 commited on
Commit
e84cf04
·
verified ·
1 Parent(s): 141c8dd

Update modeling_internlm2.py

Browse files
Files changed (1) hide show
  1. modeling_internlm2.py +3 -3
modeling_internlm2.py CHANGED
@@ -462,9 +462,9 @@ class InternLM2Attention(nn.Module):
462
  '''
463
  # compute headmask based on attn weights ratio.
464
  if attn_weights.shape[2]>1:
465
- res_v = torch.sum( attn_weights[0,:,:,41:image_token_num],dim=[1,2])
466
- res_t = torch.sum( attn_weights[0,:,:,image_token_num:],dim=[1,2])
467
- res_s = torch.sum( attn_weights[0,:,:,:41],dim=[1,2])
468
  res = res_v/(res_t+res_s)
469
  torch.save(res, 'headcut_mask_8B/'+str(idx)+'.pth')
470
  if idx ==31:
 
462
  '''
463
  # compute headmask based on attn weights ratio.
464
  if attn_weights.shape[2]>1:
465
+ res_v = torch.sum( attn_weights[0,:,:,self.offset:self.offset+image_token_num],dim=[1,2])
466
+ res_t = torch.sum( attn_weights[0,:,:,self.offset+image_token_num:],dim=[1,2])
467
+ res_s = torch.sum( attn_weights[0,:,:,:self.offset],dim=[1,2])
468
  res = res_v/(res_t+res_s)
469
  torch.save(res, 'headcut_mask_8B/'+str(idx)+'.pth')
470
  if idx ==31: