ccdv committed
Commit 3406624 · 1 Parent(s): 9bd71c3

replace 1e4 mask

Files changed (2)
  1. README.md +1 -0
  2. modeling_lsg_barthez.py +11 -12
README.md CHANGED
@@ -46,6 +46,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file

 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
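
For reference, these README options map to config fields that can be overridden at load time. A minimal sketch, assuming a hypothetical checkpoint name and the usual trust_remote_code custom-code path (not taken from this commit):

from transformers import AutoModel

# Hypothetical checkpoint name and values: adjust to the actual LSG-BARThez repo/config.
model = AutoModel.from_pretrained(
    "ccdv/lsg-barthez-4096",
    trust_remote_code=True,   # loads the custom modeling_lsg_barthez.py shipped with the checkpoint
    block_size=128,           # local block size
    sparse_block_size=128,    # sparse block size
    sparsity_factor=2,        # sparsity factor
    mask_first_token=True,    # mask the first token (redundant with the first global token)
)
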
modeling_lsg_barthez.py CHANGED
@@ -3,7 +3,6 @@ import torch
 from transformers.models.bart.modeling_bart import *
 from transformers.models.bart.modeling_bart import _expand_mask
 import torch.nn as nn
-from torch.nn import BCEWithLogitsLoss
 import sys
 
 AUTO_MAP = {
@@ -16,7 +15,7 @@ AUTO_MAP = {
 
 class LSGMBartConfig(BartConfig):
     """
-    This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
+    This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -267,7 +266,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -10000
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -296,7 +295,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -10000
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -425,7 +424,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = - (1. - mask.clamp(0, 1)) * 1e4
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -490,7 +489,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = -10000 * (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
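
For context, the point of this commit is to replace the hard-coded -10000 / 1e4 masking constants with torch.finfo(dtype).min, so the additive attention mask saturates correctly in float16/bfloat16 as well as float32. A minimal standalone sketch of the pattern (tensor names and shapes here are illustrative, not taken from the model):

import torch

def additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # 1 = keep, 0 = padding; padded positions get the most negative value of `dtype`
    # instead of a fixed -10000, which also behaves correctly in fp16/bf16.
    return ((1.0 - attention_mask.clamp(0, 1)) * torch.finfo(dtype).min).to(dtype)

scores = torch.randn(2, 8, 8, dtype=torch.float16)       # toy attention scores (n, q, k)
pad_mask = torch.tensor([[1] * 6 + [0] * 2] * 2)          # last two key positions are padding
scores = scores + additive_mask(pad_mask, scores.dtype)[:, None, :]
probs = scores.softmax(dim=-1, dtype=torch.float32)       # padded keys get ~zero weight
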
 
@@ -739,7 +738,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:,0] = 0
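
The mask_first_token hunk simply zeroes position 0 of the (possibly default) attention mask, since that token is already covered by the first global token; the default mask is now also built in the inputs' dtype. A rough standalone illustration (names and shapes are assumptions, not the model's):

import torch

n, t = 2, 6                                               # batch size, sequence length
inputs_ = torch.randn(n, t, 16, dtype=torch.float16)      # toy encoder inputs

# Default mask: attend everywhere, built with the inputs' device and dtype.
attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)

mask_first_token = True
if mask_first_token:
    attention_mask[:, 0] = 0                              # position 0 is covered by the global token
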
 
@@ -891,7 +890,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
         )
 
 
-class LSGMBartDecoder(BartDecoder, LSGMBartPretrainedModel):
+class LSGMBartDecoder(LSGMBartPretrainedModel, BartDecoder):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGMBartDecoderLayer`
     Args:
@@ -1032,7 +1031,7 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
         )
 
 
-class LSGMBartForConditionalGeneration(BartForConditionalGeneration, LSGMBartPretrainedModel):
+class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, BartForConditionalGeneration):
 
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
@@ -1048,7 +1047,7 @@ class LSGMBartForConditionalGeneration(BartForConditionalGeneration, LSGMBartPretrainedModel):
         self.post_init()
 
 
-class LSGMBartForSequenceClassification(BartForSequenceClassification, LSGMBartPretrainedModel):
+class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, BartForSequenceClassification):
 
     def __init__(self, config: LSGMBartConfig, **kwargs):
 
@@ -1064,7 +1063,7 @@ class LSGMBartForSequenceClassification(BartForSequenceClassification, LSGMBartPretrainedModel):
         self.model._init_weights(self.classification_head.out_proj)
 
 
-class LSGMBartForQuestionAnswering(BartForQuestionAnswering, LSGMBartPretrainedModel):
+class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, BartForQuestionAnswering):
 
     def __init__(self, config: LSGMBartConfig):
 
@@ -1093,7 +1092,7 @@ class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
         return self.decoder(*args, **kwargs)
 
 
-class LSGMBartForCausalLM(BartForCausalLM, LSGMBartPretrainedModel):
+class LSGMBartForCausalLM(LSGMBartPretrainedModel, BartForCausalLM):
 
     def __init__(self, config: LSGMBartConfig):
 
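
The remaining hunks only reorder the base classes so that LSGMBartPretrainedModel comes first. In Python's method resolution order the leftmost base wins, so attributes defined on the LSG base (for example config_class or weight-initialization helpers) are presumably meant to take precedence over the stock BART classes. A stripped-down illustration with toy stand-in classes, not the real transformers ones:

# Toy stand-ins for the real classes, just to show the effect of base-class order.
class BartForCausalLM:
    config_class = "BartConfig"

class LSGMBartPretrainedModel:
    config_class = "LSGMBartConfig"

# Old order: BartForCausalLM is searched first, so its attributes win.
class OldLSGMBartForCausalLM(BartForCausalLM, LSGMBartPretrainedModel):
    pass

# New order: LSG-specific attributes are resolved first.
class NewLSGMBartForCausalLM(LSGMBartPretrainedModel, BartForCausalLM):
    pass

assert OldLSGMBartForCausalLM.config_class == "BartConfig"
assert NewLSGMBartForCausalLM.config_class == "LSGMBartConfig"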