bourdoiscatie commited on
Commit
6c8b5cc
·
verified ·
1 Parent(s): d0f7321

Update custom_heads_flash_t5.py

Browse files
Files changed (1) hide show
  1. custom_heads_flash_t5.py +1 -92
custom_heads_flash_t5.py CHANGED
@@ -257,6 +257,7 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
257
  Labels for position (index) of the end of the labelled span for computing the token classification loss.
258
  Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
259
  are not taken into account for computing the loss.
 
260
  Returns:
261
  """
262
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -306,95 +307,3 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
306
  hidden_states=encoder_outputs.hidden_states,
307
  attentions=encoder_outputs.attentions,
308
  )
309
-
310
-
311
-
312
- class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
313
- _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
314
-
315
- def __init__(self, config: FlashT5Config):
316
- super().__init__(config)
317
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
318
-
319
- encoder_config = copy.deepcopy(config)
320
- encoder_config.is_decoder = False
321
- encoder_config.is_encoder_decoder = False
322
- self.encoder = FlashT5Stack(encoder_config, self.shared)
323
- self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
324
-
325
- # Initialize weights and apply final processing
326
- self.post_init()
327
-
328
- self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
329
- self.qa_outputs.bias.data.zero_()
330
-
331
- self.model_parallel = False
332
-
333
- def forward(
334
- self,
335
- input_ids: Optional[torch.LongTensor] = None,
336
- attention_mask: Optional[torch.FloatTensor] = None,
337
- head_mask: Optional[torch.FloatTensor] = None,
338
- inputs_embeds: Optional[torch.FloatTensor] = None,
339
- start_positions: Optional[torch.LongTensor] = None,
340
- end_positions: Optional[torch.LongTensor] = None,
341
- output_attentions: Optional[bool] = None,
342
- output_hidden_states: Optional[bool] = None,
343
- return_dict: Optional[bool] = None,
344
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
345
- r"""
346
- Returns:
347
- Example:
348
- ```python
349
- >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
350
- >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
351
- >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
352
- >>> input_ids = tokenizer(
353
- ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
354
- ... ).input_ids # Batch size 1
355
- >>> outputs = model(input_ids=input_ids)
356
- >>> start_logits = outputs.start_logits
357
- >>> end_logits = outputs.end_logits
358
- ```"""
359
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
360
-
361
- outputs = self.encoder(
362
- input_ids,
363
- attention_mask=attention_mask,
364
- inputs_embeds=inputs_embeds,
365
- )
366
- sequence_output = outputs[0]
367
-
368
- logits = self.qa_outputs(sequence_output)
369
- start_logits, end_logits = logits.split(1, dim=-1)
370
- start_logits = start_logits.squeeze(-1).contiguous()
371
- end_logits = end_logits.squeeze(-1).contiguous()
372
-
373
- total_loss = None
374
- if start_positions is not None and end_positions is not None:
375
- # If we are on multi-GPU, split add a dimension
376
- if len(start_positions.size()) > 1:
377
- start_positions = start_positions.squeeze(-1).to(start_logits.device)
378
- if len(end_positions.size()) > 1:
379
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
380
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
381
- ignored_index = start_logits.size(1)
382
- start_positions = start_positions.clamp(0, ignored_index)
383
- end_positions = end_positions.clamp(0, ignored_index)
384
-
385
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
386
- start_loss = loss_fct(start_logits, start_positions)
387
- end_loss = loss_fct(end_logits, end_positions)
388
- total_loss = (start_loss + end_loss) / 2
389
-
390
- if not return_dict:
391
- output = (start_logits, end_logits) + outputs[1:]
392
- return ((total_loss,) + output) if total_loss is not None else output
393
-
394
- return QuestionAnsweringModelOutput(
395
- loss=total_loss,
396
- start_logits=start_logits,
397
- end_logits=end_logits,
398
- hidden_states=outputs.hidden_states,
399
- attentions=outputs.attentions,
400
- )
 
257
  Labels for position (index) of the end of the labelled span for computing the token classification loss.
258
  Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
259
  are not taken into account for computing the loss.
260
+
261
  Returns:
262
  """
263
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
307
  hidden_states=encoder_outputs.hidden_states,
308
  attentions=encoder_outputs.attentions,
309
  )