from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import BartConfig, BartPretrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from .bart_generation_mixin import GenerationMixin
from .bart_model import BartCustomModel
from .config import BartCustomConfig
from .custom_constants import BartConstants
from .custom_outputs import CustomSeq2SeqLMOutput

logger = logging.get_logger(__name__)


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BartConstants.BART_START_DOCSTRING
)
class BartCustomForConditionalGeneration(BartPretrainedModel, GenerationMixin): |
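    """
    The custom BART model with a language modeling head on top: a linear projection from `d_model` to the shared
    vocabulary plus a `final_logits_bias` buffer. Besides the standard BART inputs, `forward` accepts
    `input_commonsense_relations`, which is passed to the underlying `BartCustomModel` as `relation_inputs`.
    """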
|
base_model_prefix = "model" |
|
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"] |
|
|
|
def __init__(self, config: BartCustomConfig): |
|
super().__init__(config) |
|
self.model = BartCustomModel(config) |
|
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) |
|
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) |
|
|
|
|
|
self.post_init() |
|
|
|
    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()
|
    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings
|
    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)
|
    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
|
    @add_start_docstrings_to_model_forward(BartConstants.BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=BartConstants.CONFIG_FOR_DOC)
    @add_end_docstrings(BartConstants.BART_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_commonsense_relations: Optional[torch.Tensor] = None,
        reduce_ce: bool = True,
    ) -> Union[Tuple, CustomSeq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Tokens equal to `config.pad_token_id` are ignored
            (masked); the loss is only computed for the remaining tokens.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
|
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            relation_inputs=input_commonsense_relations,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
        masked_lm_loss = None
        if labels is not None:
            # Padding positions are masked via `ignore_index`; `reduce_ce=False` keeps the per-token losses
            # (equivalent to the deprecated `reduce` argument the loss used to receive).
            loss_fct = CrossEntropyLoss(
                reduction="mean" if reduce_ce else "none", ignore_index=self.config.pad_token_id
            )
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
        return CustomSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            head_mask=outputs.encoder_head_mask,
        )
|
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        # Cut decoder_input_ids to the last token if past key values are used (incremental decoding).
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined, input_ids are not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }
|
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # Cached cross-attention states do not have to be reordered -> they are always the same.
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past
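

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes a trained checkpoint compatible with `BartCustomConfig` exists at a
# hypothetical local path and that a standard BART tokenizer applies; the shape
# expected for `input_commonsense_relations` depends on `BartCustomModel` and is
# therefore not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartCustomForConditionalGeneration.from_pretrained("path/to/checkpoint")  # hypothetical path

    inputs = tokenizer("UN chief says there is no military solution in Syria", return_tensors="pt")
    # Plain generation without commonsense relations (`input_commonsense_relations` defaults to None).
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=32)
    print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))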
|
|