Jackmin801 committed on
Commit
5dd64d0
·
1 Parent(s): 43f3955

disable sdpa and dynamically allocate alibi bias

Browse files
Files changed (2) hide show
  1. configuration_bert.py +1 -1
  2. modeling_bert.py +1 -22
configuration_bert.py CHANGED
@@ -128,7 +128,7 @@ class JinaBertConfig(PretrainedConfig):
128
  classifier_dropout=None,
129
  feed_forward_type="original",
130
  emb_pooler=None,
131
- attn_implementation='torch',
132
  **kwargs,
133
  ):
134
  super().__init__(pad_token_id=pad_token_id, **kwargs)
 
128
  classifier_dropout=None,
129
  feed_forward_type="original",
130
  emb_pooler=None,
131
+ attn_implementation=None,
132
  **kwargs,
133
  ):
134
  super().__init__(pad_token_id=pad_token_id, **kwargs)
modeling_bert.py CHANGED
@@ -697,11 +697,6 @@ class JinaBertEncoder(nn.Module):
697
  )
698
  self.gradient_checkpointing = False
699
  self.num_attention_heads = config.num_attention_heads
700
- self.register_buffer(
701
- "alibi",
702
- self.rebuild_alibi_tensor(size=config.max_position_embeddings),
703
- persistent=False,
704
- )
705
 
706
  def rebuild_alibi_tensor(
707
  self, size: int, device: Optional[Union[torch.device, str]] = None
@@ -769,23 +764,7 @@ class JinaBertEncoder(nn.Module):
769
 
770
  # Add alibi matrix to extended_attention_mask
771
  _, seqlen, _ = hidden_states.size()
772
- if self._current_alibi_size < seqlen:
773
- # Rebuild the alibi tensor when needed
774
- warnings.warn(
775
- f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
776
- )
777
- self.register_buffer(
778
- "alibi",
779
- self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
780
- hidden_states.dtype
781
- ),
782
- persistent=False,
783
- )
784
- elif self.alibi.device != hidden_states.device:
785
- # Device catch-up
786
- self.alibi = self.alibi.to(hidden_states.device)
787
-
788
- alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
789
  if self.gradient_checkpointing and self.training:
790
  if use_cache:
791
  logger.warning_once(
 
697
  )
698
  self.gradient_checkpointing = False
699
  self.num_attention_heads = config.num_attention_heads
 
 
 
 
 
700
 
701
  def rebuild_alibi_tensor(
702
  self, size: int, device: Optional[Union[torch.device, str]] = None
 
764
 
765
  # Add alibi matrix to extended_attention_mask
766
  _, seqlen, _ = hidden_states.size()
767
+ alibi_bias = self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(hidden_states.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
  if self.gradient_checkpointing and self.training:
769
  if use_cache:
770
  logger.warning_once(