Fix @dataclass decorator bug reported at https://github.com/huggingface/transformers/issues/30412
#5
by
pamessina
- opened
- modeling_cxrbert.py +5 -2
modeling_cxrbert.py
CHANGED
@@ -14,11 +14,14 @@ from transformers.modeling_outputs import ModelOutput
|
|
14 |
|
15 |
from .configuration_cxrbert import CXRBertConfig
|
16 |
|
|
|
|
|
17 |
BERTTupleOutput = Tuple[T, T, T, T, T]
|
18 |
|
|
|
19 |
class CXRBertOutput(ModelOutput):
|
20 |
-
last_hidden_state: torch.FloatTensor
|
21 |
-
logits: torch.FloatTensor
|
22 |
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
23 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
24 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
14 |
|
15 |
from .configuration_cxrbert import CXRBertConfig
|
16 |
|
17 |
+
from dataclasses import dataclass # manually added due to this bug: https://github.com/huggingface/transformers/issues/30412
|
18 |
+
|
19 |
BERTTupleOutput = Tuple[T, T, T, T, T]
|
20 |
|
21 |
+
@dataclass # manually added due to this bug: https://github.com/huggingface/transformers/issues/30412
|
22 |
class CXRBertOutput(ModelOutput):
|
23 |
+
last_hidden_state: torch.FloatTensor = None # None added. Not present in the original code
|
24 |
+
logits: torch.FloatTensor = None # None added. Not present in the original code
|
25 |
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
26 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
27 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|