anicolson committed on
Commit f063116
1 Parent(s): ea64610

Upload model

Files changed (4):
  1. config.json +4 -4
  2. generation_config.json +1 -1
  3. modelling_multi.py +425 -0
  4. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,10 +1,10 @@
  {
  "_commit_hash": null,
  "architectures": [
- "VariableCXREncoderDecoderModel"
+ "MultiCXREncoderDecoderModel"
  ],
  "auto_map": {
- "AutoModel": "modelling_variable.VariableCXREncoderDecoderModel"
+ "AutoModel": "modelling_multi.MultiCXREncoderDecoderModel"
  },
  "decoder": {
  "_name_or_path": "",
@@ -78,7 +78,7 @@
  "top_p": 1.0,
  "torch_dtype": null,
  "torchscript": false,
- "transformers_version": "4.28.1",
+ "transformers_version": "4.31.0",
  "type_vocab_size": 2,
  "typical_p": 1.0,
  "use_bfloat16": false,
@@ -2243,7 +2243,7 @@
  "top_p": 1.0,
  "torch_dtype": "float32",
  "torchscript": false,
- "transformers_version": "4.28.1",
+ "transformers_version": "4.31.0",
  "typical_p": 1.0,
  "use_bfloat16": false
  },
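The updated auto_map points AutoModel at the custom class in modelling_multi.py. A minimal loading sketch, not part of this commit: the repository id below is a placeholder, and transformers >= 4.31.0 is assumed to match the bumped transformers_version.

from transformers import AutoModel

# trust_remote_code=True is required so that AutoModel can resolve
# "modelling_multi.MultiCXREncoderDecoderModel" from the repository files:
model = AutoModel.from_pretrained(
    "anicolson/multi-cxr-encoder-decoder",  # placeholder repository id, not stated in this commit
    trust_remote_code=True,
)
model.eval()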
generation_config.json CHANGED
@@ -1,5 +1,5 @@
  {
  "_from_model_config": true,
  "pad_token_id": 0,
- "transformers_version": "4.28.1"
+ "transformers_version": "4.31.0"
  }
modelling_multi.py ADDED
@@ -0,0 +1,425 @@
+ import os
+ from typing import Any, Optional, Tuple, Union
+
+ import torch
+ import transformers
+ from torch.nn import CrossEntropyLoss
+ from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import \
+     VisionEncoderDecoderConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class CvtWithProjectionHeadConfig(transformers.CvtConfig):
+     def __init__(self, projection_size: int = None, **kwargs: Any) -> None:
+         super().__init__(**kwargs)
+         self.projection_size = projection_size
+
+
+ class ModelOutputWithProjectionEmbedding(transformers.modeling_outputs.ModelOutput):
+     last_hidden_state: torch.FloatTensor
+     attention_mask: torch.FloatTensor
+
+
+ class CvtProjectionHead(torch.nn.Module):
+
+     def __init__(self, config) -> None:
+         super().__init__()
+
+         # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/models/cvt/modeling_cvt.py#L657
+         self.layer_norm = torch.nn.LayerNorm(config.embed_dim[-1], eps=config.layer_norm_eps)
+
+         # No bias, as this follows layer normalisation, which has a bias:
+         self.projection = torch.nn.Linear(config.embed_dim[-1], config.projection_size, bias=False)
+
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.layer_norm(x)
+         x = self.projection(x)
+         return x
+
+
+ class MultiCvtWithProjectionHead(transformers.CvtPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.cvt = transformers.CvtModel(config, add_pooling_layer=False)
+         self.projection_head = CvtProjectionHead(config)
+
+         # Initialize weights and apply final processing:
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, ModelOutputWithProjectionEmbedding]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Flatten the batch and study_id dimensions:
+         outputs = self.cvt(
+             pixel_values.view(-1, *pixel_values.shape[2:]),
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # Flatten h x w:
+         last_hidden_state = torch.flatten(outputs.last_hidden_state, 2)
+
+         # Project the features for each spatial position to the decoder's hidden size:
+         projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))
+
+         # Concatenate the features for each chest X-ray:
+         projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])
+
+         # Derive the attention mask from the pixel values:
+         attention_mask = (pixel_values[:, :, 0, 0, 0] != 0.0).repeat_interleave(last_hidden_state.shape[-1], dim=1)
+
+         if not return_dict:
+             return projection
+
+         return ModelOutputWithProjectionEmbedding(
+             last_hidden_state=projection, attention_mask=attention_mask,
+         )
+
+
+ class MultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
+
+     config_class = VisionEncoderDecoderConfig
+     base_model_prefix = "vision_encoder_decoder"
+     main_input_name = "pixel_values"
+     supports_gradient_checkpointing = True
+
+     def __init__(
+         self,
+         config: Optional[PretrainedConfig] = None,
+         encoder: Optional[PreTrainedModel] = None,
+         decoder: Optional[PreTrainedModel] = None,
+     ):
+
+         if decoder:
+             assert decoder.config.add_cross_attention, '"add_cross_attention" must be True for the given decoder'
+             assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'
+
+         if config is None and (encoder is None or decoder is None):
+             raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+         if config is None:
+             config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+         else:
+             if not isinstance(config, self.config_class):
+                 raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+         config.tie_word_embeddings = False
+
+         # Initialise with the config:
+         PreTrainedModel.__init__(self, config)
+
+         # Encoder:
+         if encoder is None:
+             encoder = MultiCvtWithProjectionHead(config=config.encoder)
+
+         # Decoder:
+         if decoder is None:
+             decoder = transformers.BertLMHeadModel(config=config.decoder)
+
+         self.encoder = encoder
+         self.decoder = decoder
+
+         if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+             logger.warning(
+                 f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+                 f" {self.config.encoder}"
+             )
+         if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+             logger.warning(
+                 f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+                 f" {self.config.decoder}"
+             )
+
+         self.encoder.config = self.config.encoder
+         self.decoder.config = self.config.decoder
+
+         # config.add_cross_attention = True
+         # config.is_decoder = True
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+         kwargs_decoder = {
+             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+         }
+
+         if encoder_outputs is None:
+             if pixel_values is None:
+                 raise ValueError("You have to specify pixel_values")
+
+             encoder_outputs = self.encoder(
+                 pixel_values,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+                 **kwargs_encoder,
+             )  # CvT does not support output_attentions.
+
+         elif isinstance(encoder_outputs, tuple):
+             encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+         encoder_hidden_states = encoder_outputs[0]
+
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_outputs.attention_mask,
+             inputs_embeds=decoder_inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             use_cache=use_cache,
+             past_key_values=past_key_values,
+             return_dict=return_dict,
+             **kwargs_decoder,
+         )
+
+         # Loss:
+         loss = None
+         if labels is not None:
+             logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+
+         if not return_dict:
+             if loss is not None:
+                 return (loss,) + decoder_outputs + encoder_outputs
+             else:
+                 return decoder_outputs + encoder_outputs
+
+         return Seq2SeqLMOutput(
+             loss=loss,
+             logits=decoder_outputs.logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             # encoder_hidden_states=encoder_outputs.hidden_states,
+             # encoder_attentions=encoder_outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         special_token_ids,
+         past_key_values=None,
+         attention_mask=None,
+         use_cache=None,
+         encoder_outputs=None,
+         **kwargs,
+     ):
+         """
+         Modification of:
+             https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
+         """
+
+         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
+         decoder_attention_mask = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
+
+         if not past_key_values:
+             token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids)
+         else:
+             token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids)
+
+         input_dict = {
+             'attention_mask': attention_mask,
+             'decoder_attention_mask': decoder_attention_mask,
+             'decoder_input_ids': decoder_inputs['input_ids'],
+             'decoder_token_type_ids': token_type_ids,
+             'encoder_outputs': encoder_outputs,
+             'past_key_values': decoder_inputs['past_key_values'],
+             'use_cache': use_cache,
+         }
+         return input_dict
+
+     def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
+         """
+         Extract token type identifiers from the token identifiers.
+
+         Argument/s:
+             token_ids - token identifiers.
+             special_token_ids - special token identifiers that indicate the separation between sections.
+             token_type_id_sections - token type identifier for each section.
+
+         Returns:
+             token_type_ids - token type identifiers.
+         """
+
+         token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
+
+         mbatch_size, seq_len = token_ids.shape
+         token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
+
+         for i, j in enumerate(special_token_ids):
+             # Find the first occurrence of the special tokens that indicate the boundary between sections:
+             cols = (token_ids == j).int().argmax(dim=1)
+             rows = torch.arange(mbatch_size, device=token_ids.device)
+
+             # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
+             cols += 1
+
+             # Ensure that the column index is not out of bounds. If 0, then token_id not present.
+             # This is safe as index 0 is always a special token (now equal to 1 due to +1):
+             rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
+             cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
+
+             # Indices that correspond to the second sequence:
+             if rows.nelement() != 0:
+                 ids = torch.stack([
+                     torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
+                         y, seq_len, device=token_ids.device,
+                     )
+                 ])
+
+                 token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
+
+         return token_type_ids
+
+     def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
+         """
+         Extract token type identifiers from the token identifiers if past != None.
+
+         Argument/s:
+             token_ids - token identifiers.
+             special_token_ids - special token identifiers that indicate the separation between sections.
+
+         Returns:
+             token_type_ids - token type identifiers.
+         """
+
+         token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
+         token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
+
+         # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
+         token_ids = token_ids[:, :-1]
+
+         for i, j in enumerate(special_token_ids):
+
+             # Find the first occurrence of the special token, which indicates the boundary between sections:
+             exists = torch.any(token_ids == j, dim=1, keepdim=True)
+             token_type_ids[exists] = token_type_id_sections[i + 1]
+
+         return token_type_ids
+
+     def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
+         """
+         Tokenize the reports and create the inputs and targets for teacher forcing.
+
+         Argument/s:
+             findings - findings section.
+             impression - impression section.
+             tokenizer - Hugging Face tokenizer.
+             max_len - maximum number of tokens.
+
+         Returns:
+             decoder_input_ids - the token identifiers for the input of the decoder.
+             decoder_attention_mask - the attention mask for the decoder_input_ids.
+             label_ids - the label token identifiers for the decoder.
+         """
+
+         # Prepare the sections for the tokenizer by placing special tokens between each section:
+         report = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
+                   zip(findings, impression)]
+
+         # Tokenize the report:
+         tokenized = tokenizer(
+             report,
+             padding='longest',
+             truncation=True,
+             max_length=max_len + 1,  # +1 to account for the bias between input and target.
+             return_tensors='pt',
+             return_token_type_ids=False,
+             add_special_tokens=False,
+         ).to(self.device)
+
+         # Modify for language modelling:
+         batch_dict = {
+
+             # Labels for the decoder (shifted right by one for autoregression):
+             'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
+
+             # Remove the last token identifier to match the sequence length of the labels:
+             'decoder_input_ids': tokenized['input_ids'][:, :-1],
+
+             # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
+             'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
+         }
+
+         return batch_dict
+
+     def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
+         """
+         Split the token identifiers into sections, then convert the token identifiers into strings.
+
+         Argument/s:
+             token_ids - token identifiers.
+             special_token_ids - special token identifiers that indicate the end of each section.
+             tokenizer - Hugging Face tokenizer.
+
+         Returns:
+             sections - a tuple of lists, one list of decoded strings per section.
+         """
+
+         _, seq_len = token_ids.shape
+
+         # The number of sections is the same as the number of special_token_ids:
+         num_sections = len(special_token_ids)
+
+         sections = {k: [] for k in range(num_sections)}
+
+         for i in token_ids:
+             prev_col = 0
+             for j, k in enumerate(special_token_ids):
+
+                 # The maximum sequence length was exceeded, thus no more tokens:
+                 if prev_col >= seq_len:
+                     sections[j].append('')
+                     continue
+
+                 # Find the first occurrence of the special token that indicates the boundary between sections:
+                 col = (i == k).int().argmax().item()
+
+                 # If equal to 0, the token was not found; set the column to the sequence length (as the decoder exceeded
+                 # the maximum sequence length):
+                 if col == 0:
+                     col = seq_len
+
+                 # Extract the section token identifiers:
+                 section_token_ids = i[prev_col:col]
+                 prev_col = col
+                 section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
+
+                 sections[j].append(section_string)
+
+         return tuple(sections.values())
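
For context, a rough end-to-end sketch of how the class above is intended to be used. This is not part of the commit: the repository id, tokenizer, 384 x 384 input size, and choice of separator tokens are assumptions, so treat it as illustrative only.

import torch
from transformers import AutoModel, AutoTokenizer

ckpt = "anicolson/multi-cxr-encoder-decoder"     # placeholder repository id
tokenizer = AutoTokenizer.from_pretrained(ckpt)  # assumes the tokenizer is hosted alongside the model
model = AutoModel.from_pretrained(ckpt, trust_remote_code=True).eval()

# Multi-image input: (batch, images, channels, height, width). Studies with fewer
# images are zero-padded; all-zero images are masked out by the encoder, which
# derives its attention mask from pixel_values[:, :, 0, 0, 0] != 0.0.
pixel_values = torch.zeros(2, 2, 3, 384, 384)     # 384 x 384 is an assumed input size
pixel_values[0, :2] = torch.rand(2, 3, 384, 384)  # study with two images
pixel_values[1, :1] = torch.rand(1, 3, 384, 384)  # study with one image

# Teacher forcing (training): build decoder inputs, attention mask, and labels.
batch = model.tokenize_report_teacher_forcing(
    findings=['Lungs are clear.', 'There is a left pleural effusion.'],
    impression=['No acute abnormality.', 'Left pleural effusion.'],
    tokenizer=tokenizer, max_len=256,
)
sep_id, eos_id = tokenizer.sep_token_id, tokenizer.eos_token_id
token_type_ids = model.token_ids_to_token_type_ids(batch['decoder_input_ids'], [sep_id])
outputs = model(
    pixel_values=pixel_values,
    decoder_input_ids=batch['decoder_input_ids'],
    decoder_attention_mask=batch['decoder_attention_mask'],
    decoder_token_type_ids=token_type_ids,
    labels=batch['label_ids'],
)
print(outputs.loss)

# Generation (inference): special_token_ids marks the findings/impression boundary,
# and split_and_decode_sections recovers the two sections from the generated ids.
output_ids = model.generate(
    pixel_values=pixel_values,
    special_token_ids=[sep_id],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=eos_id,
    pad_token_id=tokenizer.pad_token_id,
    max_length=256,
    num_beams=1,
)
findings, impression = model.split_and_decode_sections(output_ids, [sep_id, eos_id], tokenizer)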
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:fd3d75dceeec1cb112f40a5b8acb031384617714e2e441947f8a1ef3bc5df878
- size 449713809
+ oid sha256:128c4fe34643bd3d0ee2648627dcb76a5c7cba2d602285b25fe3de06885d4867
+ size 449709389