Add @dataclass decorator to CXRBertOutput

#6
Files changed (1) hide show
  1. modelling_cxrbert.py +146 -0
modelling_cxrbert.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Any, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch import Tensor as T
13
+ from transformers import BertForMaskedLM
14
+ from transformers.modeling_outputs import ModelOutput
15
+
16
+ from health_multimodal.text.model.configuration_cxrbert import CXRBertConfig
17
+
18
+ BERTTupleOutput = Tuple[T, T, T, T, T]
19
+
20
+
21
+ @dataclass
22
+ class CXRBertOutput(ModelOutput):
23
+ last_hidden_state: torch.FloatTensor
24
+ logits: Optional[torch.FloatTensor] = None
25
+ cls_projected_embedding: Optional[torch.FloatTensor] = None
26
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
27
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
28
+
29
+
30
+ class BertProjectionHead(nn.Module):
31
+ """Projection head to be used with BERT CLS token.
32
+
33
+ This is similar to ``BertPredictionHeadTransform`` in HuggingFace.
34
+
35
+ :param config: Configuration for BERT.
36
+ """
37
+
38
+ def __init__(self, config: CXRBertConfig) -> None:
39
+ super().__init__()
40
+ self.dense_to_hidden = nn.Linear(config.hidden_size, config.projection_size)
41
+ self.transform_act_fn = nn.functional.gelu
42
+ self.LayerNorm = nn.LayerNorm(config.projection_size, eps=1e-12)
43
+ self.dense_to_output = nn.Linear(config.projection_size, config.projection_size)
44
+
45
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
46
+ hidden_states = self.dense_to_hidden(hidden_states)
47
+ hidden_states = self.transform_act_fn(hidden_states)
48
+ hidden_states = self.LayerNorm(hidden_states)
49
+ hidden_states = self.dense_to_output(hidden_states)
50
+
51
+ return hidden_states
52
+
53
+
54
+ class CXRBertModel(BertForMaskedLM):
55
+ """
56
+ Implements the CXR-BERT model outlined in the manuscript:
57
+ Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
58
+ https://link.springer.com/chapter/10.1007/978-3-031-20059-5_1
59
+
60
+ Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projection "[CLS]" token is
61
+ used to align the latent vectors of image and text modalities.
62
+ """
63
+
64
+ config_class = CXRBertConfig # type: ignore
65
+
66
+ def __init__(self, config: CXRBertConfig):
67
+ super().__init__(config)
68
+
69
+ self.cls_projection_head = BertProjectionHead(config)
70
+ self.init_weights()
71
+
72
+ def forward(
73
+ self,
74
+ input_ids: torch.Tensor,
75
+ attention_mask: torch.Tensor,
76
+ token_type_ids: Optional[torch.Tensor] = None,
77
+ position_ids: Optional[torch.Tensor] = None,
78
+ head_mask: Optional[torch.Tensor] = None,
79
+ inputs_embeds: Optional[torch.Tensor] = None,
80
+ output_attentions: Optional[bool] = None,
81
+ output_hidden_states: Optional[bool] = None,
82
+ output_cls_projected_embedding: Optional[bool] = None,
83
+ return_dict: Optional[bool] = None,
84
+ **kwargs: Any
85
+ ) -> Union[BERTTupleOutput, CXRBertOutput]:
86
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
87
+
88
+ bert_for_masked_lm_output = super().forward(
89
+ input_ids=input_ids,
90
+ attention_mask=attention_mask,
91
+ token_type_ids=token_type_ids,
92
+ position_ids=position_ids,
93
+ head_mask=head_mask,
94
+ inputs_embeds=inputs_embeds,
95
+ output_attentions=output_attentions,
96
+ output_hidden_states=True,
97
+ return_dict=True,
98
+ )
99
+
100
+ last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
101
+ cls_projected_embedding = (
102
+ self.cls_projection_head(last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
103
+ )
104
+
105
+ if return_dict:
106
+ return CXRBertOutput(
107
+ last_hidden_state=last_hidden_state,
108
+ logits=bert_for_masked_lm_output.logits,
109
+ cls_projected_embedding=cls_projected_embedding,
110
+ hidden_states=bert_for_masked_lm_output.hidden_states if output_hidden_states else None,
111
+ attentions=bert_for_masked_lm_output.attentions,
112
+ )
113
+ else:
114
+ return (
115
+ last_hidden_state,
116
+ bert_for_masked_lm_output.logits,
117
+ cls_projected_embedding,
118
+ bert_for_masked_lm_output.hidden_states,
119
+ bert_for_masked_lm_output.attentions,
120
+ )
121
+
122
+ def get_projected_text_embeddings(
123
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, normalize_embeddings: bool = True
124
+ ) -> torch.Tensor:
125
+ """
126
+ Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
127
+ The joint latent space is trained using a contrastive objective between image and text data modalities.
128
+
129
+ :param input_ids: (batch_size, sequence_length)
130
+ :param attention_mask: (batch_size, sequence_length)
131
+ :param normalize_embeddings: Whether to l2-normalise the embeddings.
132
+ :return: (batch_size, projection_size)
133
+ """
134
+
135
+ outputs = self.forward(
136
+ input_ids=input_ids, attention_mask=attention_mask, output_cls_projected_embedding=True, return_dict=True
137
+ )
138
+ assert isinstance(outputs, CXRBertOutput)
139
+
140
+ cls_projected_embedding = outputs.cls_projected_embedding
141
+ assert cls_projected_embedding is not None
142
+
143
+ if normalize_embeddings:
144
+ return F.normalize(cls_projected_embedding, dim=1)
145
+
146
+ return cls_projected_embedding