Fix @dataclass decorator bug reported at https://github.com/huggingface/transformers/issues/30412

#5
Files changed (1) hide show
  1. modeling_cxrbert.py +5 -2
modeling_cxrbert.py CHANGED
@@ -14,11 +14,14 @@ from transformers.modeling_outputs import ModelOutput
14
 
15
  from .configuration_cxrbert import CXRBertConfig
16
 
 
 
17
  BERTTupleOutput = Tuple[T, T, T, T, T]
18
 
 
19
  class CXRBertOutput(ModelOutput):
20
- last_hidden_state: torch.FloatTensor
21
- logits: torch.FloatTensor
22
  cls_projected_embedding: Optional[torch.FloatTensor] = None
23
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
24
  attentions: Optional[Tuple[torch.FloatTensor]] = None
 
14
 
15
  from .configuration_cxrbert import CXRBertConfig
16
 
17
+ from dataclasses import dataclass # manually added due to this bug: https://github.com/huggingface/transformers/issues/30412
18
+
19
  BERTTupleOutput = Tuple[T, T, T, T, T]
20
 
21
+ @dataclass # manually added due to this bug: https://github.com/huggingface/transformers/issues/30412
22
  class CXRBertOutput(ModelOutput):
23
+ last_hidden_state: torch.FloatTensor = None # None added. Not present in the original code
24
+ logits: torch.FloatTensor = None # None added. Not present in the original code
25
  cls_projected_embedding: Optional[torch.FloatTensor] = None
26
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
27
  attentions: Optional[Tuple[torch.FloatTensor]] = None