prajdabre VarunGumma commited on
Commit
871b62f
·
verified ·
1 Parent(s): 0b0e440

Update modeling_rotary_indictrans.py (#5)

Browse files

- Update modeling_rotary_indictrans.py (2572b0000ff5bd5ee41fa49a22c8c4e6fbebd6a8)


Co-authored-by: Varun Gumma <[email protected]>

Files changed (1) hide show
  1. modeling_rotary_indictrans.py +26 -25
modeling_rotary_indictrans.py CHANGED
@@ -31,16 +31,22 @@ from transformers.generation import GenerationMixin
31
  from transformers.modeling_utils import PreTrainedModel
32
  from .configuration_rotary_indictrans import RotaryIndicTransConfig
33
 
34
- from flash_attn import flash_attn_func, flash_attn_varlen_func
35
- from flash_attn.bert_padding import (
36
- index_first_axis,
37
- pad_input,
38
- unpad_input,
39
- )
40
-
41
  logger = logging.get_logger(__name__)
42
  device = "cuda" if torch.cuda.is_available() else "cpu"
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
46
  def _get_unpad_data(attention_mask):
@@ -1401,8 +1407,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1401
 
1402
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->RotaryIndicTrans
1403
  class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
1404
- _tied_weights_keys = None
1405
-
1406
  def __init__(self, config: RotaryIndicTransConfig):
1407
  super().__init__(config)
1408
 
@@ -1497,10 +1501,11 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
1497
 
1498
 
1499
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
1500
- class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel, GenerationMixin):
 
 
1501
  base_model_prefix = "model"
1502
- _tied_weights_keys = None
1503
- _label_smoothing = 0.0
1504
 
1505
  def __init__(self, config: RotaryIndicTransConfig):
1506
  super().__init__(config)
@@ -1509,19 +1514,16 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
1509
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1510
  )
1511
 
1512
- if config.share_decoder_input_output_embed:
1513
- self.lm_head.weight = self.model.decoder.embed_tokens.weight
1514
-
1515
  self.post_init()
1516
 
1517
- def tie_weights(self):
1518
- pass
1519
-
1520
  def get_encoder(self):
1521
- return self.model.get_encoder()
1522
 
1523
  def get_decoder(self):
1524
- return self.model.get_decoder()
 
 
 
1525
 
1526
  def get_output_embeddings(self):
1527
  return self.lm_head
@@ -1529,8 +1531,9 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
1529
  def set_output_embeddings(self, new_embeddings):
1530
  self.lm_head = new_embeddings
1531
 
1532
- def set_label_smoothing(self, label_smoothing):
1533
- self._label_smoothing = label_smoothing
 
1534
 
1535
  def forward(
1536
  self,
@@ -1594,8 +1597,6 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
1594
  masked_lm_loss = F.cross_entropy(
1595
  input=lm_logits.view(-1, self.config.decoder_vocab_size),
1596
  target=labels.view(-1),
1597
- ignore_index=-100,
1598
- label_smoothing=self._label_smoothing,
1599
  )
1600
 
1601
  if not return_dict:
@@ -1652,4 +1653,4 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
1652
  past_state.index_select(0, beam_idx) for past_state in layer_past
1653
  ),
1654
  )
1655
- return reordered_past
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from .configuration_rotary_indictrans import RotaryIndicTransConfig
33
 
 
 
 
 
 
 
 
34
  logger = logging.get_logger(__name__)
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
 
37
+ try:
38
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
39
+ from flash_attn.bert_padding import (
40
+ index_first_axis,
41
+ pad_input,
42
+ unpad_input,
43
+ )
44
+ except ImportError:
45
+ logger.warning(
46
+ "It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans."
47
+ "Falling back to the default `eager` implementation."
48
+ )
49
+
50
 
51
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
52
  def _get_unpad_data(attention_mask):
 
1407
 
1408
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->RotaryIndicTrans
1409
  class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
 
 
1410
  def __init__(self, config: RotaryIndicTransConfig):
1411
  super().__init__(config)
1412
 
 
1501
 
1502
 
1503
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
1504
+ class RotaryIndicTransForConditionalGeneration(
1505
+ RotaryIndicTransPreTrainedModel, GenerationMixin
1506
+ ):
1507
  base_model_prefix = "model"
1508
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
 
1509
 
1510
  def __init__(self, config: RotaryIndicTransConfig):
1511
  super().__init__(config)
 
1514
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1515
  )
1516
 
 
 
 
1517
  self.post_init()
1518
 
 
 
 
1519
  def get_encoder(self):
1520
+ return self.model.encoder
1521
 
1522
  def get_decoder(self):
1523
+ return self.model.decoder
1524
+
1525
+ def get_input_embeddings(self):
1526
+ return self.model.encoder.embed_tokens
1527
 
1528
  def get_output_embeddings(self):
1529
  return self.lm_head
 
1531
  def set_output_embeddings(self, new_embeddings):
1532
  self.lm_head = new_embeddings
1533
 
1534
+ def tie_weights(self):
1535
+ if self.config.share_decoder_input_output_embed:
1536
+ self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
1537
 
1538
  def forward(
1539
  self,
 
1597
  masked_lm_loss = F.cross_entropy(
1598
  input=lm_logits.view(-1, self.config.decoder_vocab_size),
1599
  target=labels.view(-1),
 
 
1600
  )
1601
 
1602
  if not return_dict:
 
1653
  past_state.index_select(0, beam_idx) for past_state in layer_past
1654
  ),
1655
  )
1656
+ return reordered_past