zwt123home123
commited on
Update modeling_internlm2.py
Browse files- modeling_internlm2.py +1 -1
modeling_internlm2.py
CHANGED
@@ -460,7 +460,7 @@ class InternLM2Attention(nn.Module):
|
|
460 |
|
461 |
if self.headcut:
|
462 |
'''
|
463 |
-
# compute
|
464 |
if attn_weights.shape[2]>1:
|
465 |
res_v = torch.sum( attn_weights[0,:,:,self.offset:self.offset+image_token_num],dim=[1,2])
|
466 |
res_t = torch.sum( attn_weights[0,:,:,self.offset+image_token_num:],dim=[1,2])
|
|
|
460 |
|
461 |
if self.headcut:
|
462 |
'''
|
463 |
+
# compute attn weights ratio for headcut mask generation.
|
464 |
if attn_weights.shape[2]>1:
|
465 |
res_v = torch.sum( attn_weights[0,:,:,self.offset:self.offset+image_token_num],dim=[1,2])
|
466 |
res_t = torch.sum( attn_weights[0,:,:,self.offset+image_token_num:],dim=[1,2])
|