replace 1e4 mask

Files changed:
- README.md: +1 -0
- modeling_lsg_barthez.py: +11 -12
README.md CHANGED
@@ -46,6 +46,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file
 
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
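For reference, the sketch below shows how these parameters could be overridden at load time. It is an illustration only: the model id is hypothetical, and it assumes the usual `trust_remote_code` workflow in which extra keyword arguments from `from_pretrained` are forwarded to the remote `LSGMBartConfig`; none of this is part of the diff itself.

```python
# Hedged sketch, not from the repository: the model id is illustrative and the
# kwargs-forwarding behaviour is the standard transformers from_pretrained path.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "ccdv/lsg-barthez-4096"  # hypothetical LSG BARThez checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    block_size=128,           # local block size
    sparse_block_size=128,    # sparse block size
    sparsity_factor=2,        # sparsity factor
    mask_first_token=True,    # mask the first token (redundant with the first global token)
)
```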
modeling_lsg_barthez.py CHANGED
@@ -3,7 +3,6 @@ import torch
 from transformers.models.bart.modeling_bart import *
 from transformers.models.bart.modeling_bart import _expand_mask
 import torch.nn as nn
-from torch.nn import BCEWithLogitsLoss
 import sys
 
 AUTO_MAP = {
@@ -16,7 +15,7 @@ AUTO_MAP = {
 
 class LSGMBartConfig(BartConfig):
     """
-    This class overrides :class:`~transformers.
+    This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -267,7 +266,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value =
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -296,7 +295,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value =
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -425,7 +424,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask =
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -490,7 +489,7 @@ class LSGMBartEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask =
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
@@ -739,7 +738,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:,0] = 0
 
@@ -891,7 +890,7 @@ class LSGMBartEncoder(LSGMBartPretrainedModel, BartEncoder):
         )
 
 
-class LSGMBartDecoder(
+class LSGMBartDecoder(LSGMBartPretrainedModel, BartDecoder):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGMBartDecoderLayer`
     Args:
@@ -1032,7 +1031,7 @@ class LSGMBartModel(LSGMBartPretrainedModel, BartModel):
         )
 
 
-class LSGMBartForConditionalGeneration(
+class LSGMBartForConditionalGeneration(LSGMBartPretrainedModel, BartForConditionalGeneration):
 
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
@@ -1048,7 +1047,7 @@ class LSGMBartForConditionalGeneration(BartForConditionalGeneration, LSGMBartPretrainedModel):
         self.post_init()
 
 
-class LSGMBartForSequenceClassification(
+class LSGMBartForSequenceClassification(LSGMBartPretrainedModel, BartForSequenceClassification):
 
     def __init__(self, config: LSGMBartConfig, **kwargs):
 
@@ -1064,7 +1063,7 @@ class LSGMBartForSequenceClassification(BartForSequenceClassification, LSGMBartPretrainedModel):
         self.model._init_weights(self.classification_head.out_proj)
 
 
-class LSGMBartForQuestionAnswering(
+class LSGMBartForQuestionAnswering(LSGMBartPretrainedModel, BartForQuestionAnswering):
 
     def __init__(self, config: LSGMBartConfig):
 
@@ -1093,7 +1092,7 @@ class LSGMBartDecoderWrapper(LSGMBartPretrainedModel):
         return self.decoder(*args, **kwargs)
 
 
-class LSGMBartForCausalLM(
+class LSGMBartForCausalLM(LSGMBartPretrainedModel, BartForCausalLM):
 
     def __init__(self, config: LSGMBartConfig):
 
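The sketch below illustrates the masking convention the `+` lines above adopt: converting a {0, 1} padding mask into an additive mask filled with `torch.finfo(dtype).min` instead of a fixed -1e4. The helper `to_additive_mask` is hypothetical and written only for illustration; it is not a function in `modeling_lsg_barthez.py`.

```python
# Standalone illustration of the mask convention used in the "+" lines above.
# to_additive_mask is a hypothetical helper, not part of modeling_lsg_barthez.py.
import torch

def to_additive_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # 0 where the position is attended, dtype-minimum where it is masked,
    # mirroring `(1. - mask.clamp(0, 1)) * torch.finfo(dtype).min`.
    mask = mask.to(dtype)
    return (1.0 - mask.clamp(0, 1)) * torch.finfo(dtype).min

scores = torch.randn(1, 4, 4)                            # raw attention scores (float32 here)
padding_mask = torch.tensor([[1, 1, 1, 0]])              # last key position is padding
additive = to_additive_mask(padding_mask, scores.dtype)  # shape (1, 4)

# Because the fill value is taken from the tensor's own dtype, the same expression
# adapts automatically to float16 or bfloat16 checkpoints, and it is as negative as
# the representation allows, which a fixed -1e4 cannot guarantee for large scores.
weights = torch.softmax(scores + additive[:, None, :], dim=-1)
print(weights[0, :, -1])  # ~0 attention weight on the padded position for every query
```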
|