Chan-Y commited on
Commit
3170836
1 Parent(s): 67e68c3

Upload processing_florence2.py

Browse files
Files changed (1) hide show
  1. processing_florence2.py +1088 -0
processing_florence2.py ADDED
@@ -0,0 +1,1088 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Florence-2.
17
+ """
18
+
19
+ import re
20
+ import logging
21
+ from typing import List, Optional, Union
22
+ import numpy as np
23
+
24
+ import torch
25
+
26
+ from transformers.feature_extraction_utils import BatchFeature
27
+ from transformers.image_utils import ImageInput, is_valid_image
28
+ from transformers.processing_utils import ProcessorMixin
29
+ from transformers.tokenization_utils_base import (
30
+ PaddingStrategy,
31
+ PreTokenizedInput,
32
+ TextInput,
33
+ TruncationStrategy,
34
+ )
35
+ from transformers.utils import TensorType
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # Copied from transformers.models.idefics2.processing_idefics2.is_url
41
+ def is_url(val) -> bool:
42
+ return isinstance(val, str) and val.startswith("http")
43
+
44
+ # Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
45
+ def is_image_or_image_url(elem):
46
+ return is_url(elem) or is_valid_image(elem)
47
+
48
+
49
+ def _is_str_or_image(elem):
50
+ return isinstance(elem, (str)) or is_image_or_image_url(elem)
51
+
52
+
53
+ class Florence2Processor(ProcessorMixin):
54
+ r"""
55
+ Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
56
+
57
+ [`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
58
+ [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
59
+
60
+ Args:
61
+ image_processor ([`CLIPImageProcessor`], *optional*):
62
+ The image processor is a required input.
63
+ tokenizer ([`BartTokenizerFast`], *optional*):
64
+ The tokenizer is a required input.
65
+ """
66
+
67
+ attributes = ["image_processor", "tokenizer"]
68
+ image_processor_class = "CLIPImageProcessor"
69
+ tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
70
+
71
+ def __init__(
72
+ self,
73
+ image_processor=None,
74
+ tokenizer=None,
75
+ ):
76
+ if image_processor is None:
77
+ raise ValueError("You need to specify an `image_processor`.")
78
+ if tokenizer is None:
79
+ raise ValueError("You need to specify a `tokenizer`.")
80
+ if not hasattr(image_processor, "image_seq_length"):
81
+ raise ValueError("Image processor is missing an `image_seq_length` attribute.")
82
+
83
+ self.image_seq_length = image_processor.image_seq_length
84
+
85
+ tokens_to_add = {
86
+ 'additional_special_tokens': \
87
+ tokenizer.additional_special_tokens + \
88
+ ['<od>', '</od>', '<ocr>', '</ocr>'] + \
89
+ [f'<loc_{x}>' for x in range(1000)] + \
90
+ ['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>']
91
+ }
92
+ tokenizer.add_special_tokens(tokens_to_add)
93
+
94
+ self.tasks_answer_post_processing_type = {
95
+ '<OCR>': 'pure_text',
96
+ '<OCR_WITH_REGION>': 'ocr',
97
+ '<CAPTION>': 'pure_text',
98
+ '<DETAILED_CAPTION>': 'pure_text',
99
+ '<MORE_DETAILED_CAPTION>': 'pure_text',
100
+ '<OD>': 'description_with_bboxes',
101
+ '<DENSE_REGION_CAPTION>': 'description_with_bboxes',
102
+ '<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
103
+ '<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
104
+ '<REGION_TO_SEGMENTATION>': 'polygons',
105
+ '<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
106
+ '<REGION_TO_CATEGORY>': 'pure_text',
107
+ '<REGION_TO_DESCRIPTION>': 'pure_text',
108
+ '<REGION_TO_OCR>': 'pure_text',
109
+ '<REGION_PROPOSAL>': 'bboxes'
110
+ }
111
+
112
+ self.task_prompts_without_inputs = {
113
+ '<OCR>': 'What is the text in the image?',
114
+ '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
115
+ '<CAPTION>': 'What does the image describe?',
116
+ '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
117
+ '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
118
+ '<OD>': 'Locate the objects with category name in the image.',
119
+ '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
120
+ '<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
121
+ }
122
+
123
+ self.task_prompts_with_input = {
124
+ '<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
125
+ '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
126
+ '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
127
+ '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
128
+ '<REGION_TO_CATEGORY>': 'What is the region {input}?',
129
+ '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
130
+ '<REGION_TO_OCR>': 'What text is in the region {input}?',
131
+ }
132
+
133
+ self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)
134
+
135
+
136
+ super().__init__(image_processor, tokenizer)
137
+
138
+ def _construct_prompts(self, text):
139
+ # replace the task tokens with the task prompts if task token is in the text
140
+ prompts = []
141
+ for _text in text:
142
+ # 1. fixed task prompts without additional inputs
143
+ for task_token, task_prompt in self.task_prompts_without_inputs.items():
144
+ if task_token in _text:
145
+ assert _text == task_token, f"Task token {task_token} should be the only token in the text."
146
+ _text = task_prompt
147
+ break
148
+ # 2. task prompts with additional inputs
149
+ for task_token, task_prompt in self.task_prompts_with_input.items():
150
+ if task_token in _text:
151
+ _text = task_prompt.format(input=_text.replace(task_token, ''))
152
+ break
153
+ prompts.append(_text)
154
+ return prompts
155
+
156
+ def __call__(
157
+ self,
158
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
159
+ images: ImageInput = None,
160
+ tokenize_newline_separately: bool = True,
161
+ padding: Union[bool, str, PaddingStrategy] = False,
162
+ truncation: Union[bool, str, TruncationStrategy] = None,
163
+ max_length=None,
164
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
165
+ do_resize: bool = None,
166
+ do_normalize: bool = None,
167
+ image_mean: Optional[Union[float, List[float]]] = None,
168
+ image_std: Optional[Union[float, List[float]]] = None,
169
+ data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
170
+ input_data_format: Optional[
171
+ Union[str, "ChannelDimension"] # noqa: F821
172
+ ] = None,
173
+ resample: "PILImageResampling" = None, # noqa: F821
174
+ do_convert_rgb: bool = None,
175
+ do_thumbnail: bool = None,
176
+ do_align_long_axis: bool = None,
177
+ do_rescale: bool = None,
178
+ ) -> BatchFeature:
179
+ """
180
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
181
+ and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
182
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
183
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
184
+ of the above two methods for more information.
185
+
186
+ Args:
187
+ text (`str`, `List[str]`, `List[List[str]]`):
188
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
189
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
190
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
191
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
192
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
193
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
194
+ number of channels, H and W are image height and width.
195
+ tokenize_newline_separately (`bool`, defaults to `True`):
196
+ Adds a separately tokenized '\n' at the end of the prompt.
197
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
198
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
199
+ index) among:
200
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
201
+ sequence if provided).
202
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
203
+ acceptable input length for the model if that argument is not provided.
204
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
205
+ lengths).
206
+ max_length (`int`, *optional*):
207
+ Maximum length of the returned list and optionally padding length (see above).
208
+ truncation (`bool`, *optional*):
209
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
210
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
211
+ If set, will return tensors of a particular framework. Acceptable values are:
212
+
213
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
214
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
215
+ - `'np'`: Return NumPy `np.ndarray` objects.
216
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
217
+
218
+ Returns:
219
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
220
+
221
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
222
+ is provided, the `input_ids` will also contain the suffix input ids.
223
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
224
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
225
+ `None`).
226
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
227
+ - **labels** -- Labels compatible with training if `suffix` is not None
228
+ """
229
+
230
+ return_token_type_ids = False
231
+
232
+ if images is None:
233
+ raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
234
+ if text is None:
235
+ logger.warning_once(
236
+ "You are using Florence-2 without a text prompt."
237
+ )
238
+ text = ""
239
+
240
+ if isinstance(text, List) and isinstance(images, List):
241
+ if len(images) < len(text):
242
+ raise ValueError(
243
+ f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
244
+ )
245
+ if _is_str_or_image(text):
246
+ text = [text]
247
+ elif isinstance(text, list) and _is_str_or_image(text[0]):
248
+ pass
249
+
250
+ pixel_values = self.image_processor(
251
+ images,
252
+ do_resize=do_resize,
253
+ do_normalize=do_normalize,
254
+ return_tensors=return_tensors,
255
+ image_mean=image_mean,
256
+ image_std=image_std,
257
+ input_data_format=input_data_format,
258
+ data_format=data_format,
259
+ resample=resample,
260
+ do_convert_rgb=do_convert_rgb,
261
+ )["pixel_values"]
262
+
263
+ if max_length is not None:
264
+ max_length -= self.image_seq_length # max_length has to account for the image tokens
265
+
266
+ text = self._construct_prompts(text)
267
+
268
+ inputs = self.tokenizer(
269
+ text,
270
+ return_tensors=return_tensors,
271
+ padding=padding,
272
+ max_length=max_length,
273
+ truncation=truncation,
274
+ return_token_type_ids=return_token_type_ids,
275
+ )
276
+
277
+ return_data = {**inputs, "pixel_values": pixel_values}
278
+
279
+ if return_token_type_ids:
280
+ labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
281
+ return_data.update({"labels": labels})
282
+ return BatchFeature(data=return_data)
283
+
284
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
285
+ def batch_decode(self, *args, **kwargs):
286
+ """
287
+ This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
288
+ refer to the docstring of this method for more information.
289
+ """
290
+ return self.tokenizer.batch_decode(*args, **kwargs)
291
+
292
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
293
+ def decode(self, *args, **kwargs):
294
+ """
295
+ This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
296
+ the docstring of this method for more information.
297
+ """
298
+ return self.tokenizer.decode(*args, **kwargs)
299
+
300
+ @property
301
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
302
+ def model_input_names(self):
303
+ tokenizer_input_names = self.tokenizer.model_input_names
304
+ image_processor_input_names = self.image_processor.model_input_names
305
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
306
+
307
+ def post_process_generation(self, text, task, image_size):
308
+ """
309
+ Post-process the output of the model to each of the task outputs.
310
+
311
+ Args:
312
+ text (`str`): The text to post-process.
313
+ task (`str`): The task to post-process the text for.
314
+ image_size (`Tuple[int, int]`): The size of the image. height x width.
315
+ """
316
+
317
+ task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
318
+ task_answer = self.post_processor(
319
+ text=text,
320
+ image_size=image_size,
321
+ parse_tasks=task_answer_post_processing_type,
322
+ )[task_answer_post_processing_type]
323
+
324
+ if task_answer_post_processing_type == 'pure_text':
325
+ final_answer = task_answer
326
+ # remove the special tokens
327
+ final_answer = final_answer.replace('<s>', '').replace('</s>', '')
328
+ elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']:
329
+ od_instances = task_answer
330
+ bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
331
+ labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
332
+ final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
333
+ elif task_answer_post_processing_type in ['ocr']:
334
+ bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
335
+ labels = [str(_od_instance['text']) for _od_instance in task_answer]
336
+ final_answer = {'quad_boxes': bboxes, 'labels': labels}
337
+ elif task_answer_post_processing_type in ['phrase_grounding']:
338
+ bboxes = []
339
+ labels = []
340
+ for _grounded_phrase in task_answer:
341
+ for _bbox in _grounded_phrase['bbox']:
342
+ bboxes.append(_bbox)
343
+ labels.append(_grounded_phrase['cat_name'])
344
+ final_answer = {'bboxes': bboxes, 'labels': labels}
345
+ elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']:
346
+ labels = []
347
+ polygons = []
348
+ for result in task_answer:
349
+ label = result['cat_name']
350
+ _polygons = result['polygons']
351
+ labels.append(label)
352
+ polygons.append(_polygons)
353
+ final_answer = {'polygons': polygons, 'labels': labels}
354
+ elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']:
355
+ bboxes = []
356
+ bboxes_labels = []
357
+ polygons = []
358
+ polygons_labels = []
359
+ for result in task_answer:
360
+ label = result['cat_name']
361
+ if 'polygons' in result:
362
+ _polygons = result['polygons']
363
+ polygons.append(_polygons)
364
+ polygons_labels.append(label)
365
+ else:
366
+ _bbox = result['bbox']
367
+ bboxes.append(_bbox)
368
+ bboxes_labels.append(label)
369
+ final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels}
370
+ else:
371
+ raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type))
372
+
373
+ final_answer = {
374
+ task: final_answer}
375
+ return final_answer
376
+
377
+ class BoxQuantizer(object):
378
+ def __init__(self, mode, bins):
379
+ self.mode = mode
380
+ self.bins = bins
381
+
382
+ def quantize(self, boxes: torch.Tensor, size):
383
+ bins_w, bins_h = self.bins # Quantization bins.
384
+ size_w, size_h = size # Original image size.
385
+ size_per_bin_w = size_w / bins_w
386
+ size_per_bin_h = size_h / bins_h
387
+ xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
388
+
389
+ if self.mode == 'floor':
390
+ quantized_xmin = (
391
+ xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
392
+ quantized_ymin = (
393
+ ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
394
+ quantized_xmax = (
395
+ xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
396
+ quantized_ymax = (
397
+ ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)
398
+
399
+ elif self.mode == 'round':
400
+ raise NotImplementedError()
401
+
402
+ else:
403
+ raise ValueError('Incorrect quantization type.')
404
+
405
+ quantized_boxes = torch.cat(
406
+ (quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
407
+ ).int()
408
+
409
+ return quantized_boxes
410
+
411
+ def dequantize(self, boxes: torch.Tensor, size):
412
+ bins_w, bins_h = self.bins # Quantization bins.
413
+ size_w, size_h = size # Original image size.
414
+ size_per_bin_w = size_w / bins_w
415
+ size_per_bin_h = size_h / bins_h
416
+ xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
417
+
418
+ if self.mode == 'floor':
419
+ # Add 0.5 to use the center position of the bin as the coordinate.
420
+ dequantized_xmin = (xmin + 0.5) * size_per_bin_w
421
+ dequantized_ymin = (ymin + 0.5) * size_per_bin_h
422
+ dequantized_xmax = (xmax + 0.5) * size_per_bin_w
423
+ dequantized_ymax = (ymax + 0.5) * size_per_bin_h
424
+
425
+ elif self.mode == 'round':
426
+ raise NotImplementedError()
427
+
428
+ else:
429
+ raise ValueError('Incorrect quantization type.')
430
+
431
+ dequantized_boxes = torch.cat(
432
+ (dequantized_xmin, dequantized_ymin,
433
+ dequantized_xmax, dequantized_ymax), dim=-1
434
+ )
435
+
436
+ return dequantized_boxes
437
+
438
+
439
+ class CoordinatesQuantizer(object):
440
+ """
441
+ Quantize coornidates (Nx2)
442
+ """
443
+
444
+ def __init__(self, mode, bins):
445
+ self.mode = mode
446
+ self.bins = bins
447
+
448
+ def quantize(self, coordinates: torch.Tensor, size):
449
+ bins_w, bins_h = self.bins # Quantization bins.
450
+ size_w, size_h = size # Original image size.
451
+ size_per_bin_w = size_w / bins_w
452
+ size_per_bin_h = size_h / bins_h
453
+ assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
454
+ x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
455
+
456
+ if self.mode == 'floor':
457
+ quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1)
458
+ quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1)
459
+
460
+ elif self.mode == 'round':
461
+ raise NotImplementedError()
462
+
463
+ else:
464
+ raise ValueError('Incorrect quantization type.')
465
+
466
+ quantized_coordinates = torch.cat(
467
+ (quantized_x, quantized_y), dim=-1
468
+ ).int()
469
+
470
+ return quantized_coordinates
471
+
472
+ def dequantize(self, coordinates: torch.Tensor, size):
473
+ bins_w, bins_h = self.bins # Quantization bins.
474
+ size_w, size_h = size # Original image size.
475
+ size_per_bin_w = size_w / bins_w
476
+ size_per_bin_h = size_h / bins_h
477
+ assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
478
+ x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
479
+
480
+ if self.mode == 'floor':
481
+ # Add 0.5 to use the center position of the bin as the coordinate.
482
+ dequantized_x = (x + 0.5) * size_per_bin_w
483
+ dequantized_y = (y + 0.5) * size_per_bin_h
484
+
485
+ elif self.mode == 'round':
486
+ raise NotImplementedError()
487
+
488
+ else:
489
+ raise ValueError('Incorrect quantization type.')
490
+
491
+ dequantized_coordinates = torch.cat(
492
+ (dequantized_x, dequantized_y), dim=-1
493
+ )
494
+
495
+ return dequantized_coordinates
496
+
497
+
498
+ class Florence2PostProcesser(object):
499
+ """
500
+ Florence-2 post process for converting text prediction to various tasks results.
501
+
502
+ Args:
503
+ config: A dict of configs.
504
+ tokenizer: A tokenizer for decoding text to spans.
505
+ sample config:
506
+ UNIFIED_POST_PROCESS:
507
+ # commom configs
508
+ NUM_BBOX_HEIGHT_BINS: 1000
509
+ NUM_BBOX_WIDTH_BINS: 1000
510
+ COORDINATES_HEIGHT_BINS: 1000
511
+ COORDINATES_WIDTH_BINS: 1000
512
+ # task specific configs, override the common configs
513
+ PRASE_TASKS:
514
+ - TASK_NAME: 'video_dense_caption'
515
+ PATTERN: 'r<time_(\d+)><time_(\d+)>([a-zA-Z0-9 ]+)'
516
+ SCORE_MODE: 'avg_cat_name_scores'
517
+ NUM_BINS: 100
518
+ - TASK_NAME: 'od'
519
+ PATTERN: 'r<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>([a-zA-Z0-9 ]+)'
520
+ SCORE_MODE: 'avg_cat_name_scores'
521
+
522
+ Returns:
523
+ parsed_dict (dict): A dict of parsed results.
524
+ """
525
+ def __init__(
526
+ self,
527
+ tokenizer=None
528
+ ):
529
+ parse_tasks = []
530
+ parse_task_configs = {}
531
+ config = self._create_default_config()
532
+ for task in config['PARSE_TASKS']:
533
+ parse_tasks.append(task['TASK_NAME'])
534
+ parse_task_configs[task['TASK_NAME']] = task
535
+
536
+ self.config = config
537
+ self.parse_tasks = parse_tasks
538
+ self.parse_tasks_configs = parse_task_configs
539
+
540
+ self.tokenizer = tokenizer
541
+ if self.tokenizer is not None:
542
+ self.all_special_tokens = set(self.tokenizer.all_special_tokens)
543
+
544
+ self.init_quantizers()
545
+ self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding()
546
+
547
+ def _create_black_list_of_phrase_grounding(self):
548
+ black_list = {}
549
+
550
+ if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']:
551
+ black_list = set(
552
+ ['it', 'I', 'me', 'mine',
553
+ 'you', 'your', 'yours',
554
+ 'he', 'him', 'his',
555
+ 'she', 'her', 'hers',
556
+ 'they', 'them', 'their', 'theirs',
557
+ 'one', 'oneself',
558
+ 'we', 'us', 'our', 'ours',
559
+ 'you', 'your', 'yours',
560
+ 'they', 'them', 'their', 'theirs',
561
+ 'mine', 'yours', 'his', 'hers', 'its',
562
+ 'ours', 'yours', 'theirs',
563
+ 'myself', 'yourself', 'himself', 'herself', 'itself',
564
+ 'ourselves', 'yourselves', 'themselves',
565
+ 'this', 'that',
566
+ 'these', 'those',
567
+ 'who', 'whom', 'whose', 'which', 'what',
568
+ 'who', 'whom', 'whose', 'which', 'that',
569
+ 'all', 'another', 'any', 'anybody', 'anyone', 'anything',
570
+ 'each', 'everybody', 'everyone', 'everything',
571
+ 'few', 'many', 'nobody', 'none', 'one', 'several',
572
+ 'some', 'somebody', 'someone', 'something',
573
+ 'each other', 'one another',
574
+ 'myself', 'yourself', 'himself', 'herself', 'itself',
575
+ 'ourselves', 'yourselves', 'themselves',
576
+ 'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
577
+ 'other objects', 'lots', 'a set',
578
+ ]
579
+ )
580
+
581
+ return black_list
582
+
583
+ def _create_default_config(self):
584
+ config = {
585
+ 'NUM_BBOX_HEIGHT_BINS': 1000,
586
+ 'NUM_BBOX_WIDTH_BINS': 1000,
587
+ 'BOX_QUANTIZATION_MODE': 'floor',
588
+ 'COORDINATES_HEIGHT_BINS': 1000,
589
+ 'COORDINATES_WIDTH_BINS': 1000,
590
+ 'COORDINATES_QUANTIZATION_MODE': 'floor',
591
+ 'PARSE_TASKS': [
592
+ {
593
+ 'TASK_NAME': 'od',
594
+ 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
595
+ },
596
+ {
597
+ 'TASK_NAME': 'ocr',
598
+ 'PATTERN': r'(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>',
599
+ 'AREA_THRESHOLD': 0.01
600
+ },
601
+ {
602
+ 'TASK_NAME': 'phrase_grounding',
603
+ 'FILTER_BY_BLACK_LIST': True
604
+ },
605
+ {
606
+ 'TASK_NAME': 'pure_text',
607
+ },
608
+ {
609
+ 'TASK_NAME': 'description_with_bboxes',
610
+ },
611
+ {
612
+ 'TASK_NAME': 'description_with_polygons',
613
+ },
614
+ {
615
+ 'TASK_NAME': 'polygons',
616
+ },
617
+ {
618
+ 'TASK_NAME': 'bboxes',
619
+ },
620
+ {
621
+ 'TASK_NAME': 'description_with_bboxes_or_polygons',
622
+ }
623
+ ]
624
+ }
625
+
626
+ return config
627
+
628
+ def init_quantizers(self):
629
+ # we have box_quantizer (od, grounding) and coordinates_quantizer (ocr, referring_segmentation)
630
+ num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
631
+ num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
632
+ box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor')
633
+ self.box_quantizer = BoxQuantizer(
634
+ box_quantization_mode,
635
+ (num_bbox_width_bins, num_bbox_height_bins),
636
+ )
637
+
638
+ num_bbox_height_bins = self.config['COORDINATES_HEIGHT_BINS'] if 'COORDINATES_HEIGHT_BINS' in self.config else self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
639
+ num_bbox_width_bins = self.config['COORDINATES_WIDTH_BINS'] if 'COORDINATES_WIDTH_BINS' in self.config else self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
640
+ box_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE') if 'COORDINATES_QUANTIZATION_MODE' in self.config else self.config.get('BOX_QUANTIZATION_MODE', 'floor')
641
+ self.coordinates_quantizer = CoordinatesQuantizer(
642
+ box_quantization_mode,
643
+ (num_bbox_width_bins, num_bbox_height_bins),
644
+ )
645
+
646
+ def decode_with_spans(self, tokenizer, token_ids):
647
+ filtered_tokens = tokenizer.convert_ids_to_tokens(
648
+ token_ids, skip_special_tokens=False)
649
+ assert len(filtered_tokens) == len(token_ids)
650
+
651
+ # To avoid mixing byte-level and unicode for byte-level BPT
652
+ # we need to build string separately for added tokens and byte-level tokens
653
+ # cf. https://github.com/huggingface/transformers/issues/1133
654
+ sub_texts = []
655
+ for token in filtered_tokens:
656
+ if token in self.all_special_tokens:
657
+ sub_texts.append(token)
658
+ else:
659
+ if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
660
+ sub_text = tokenizer.convert_tokens_to_string([token])
661
+ elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
662
+ # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
663
+ # Note: Do not strip sub_text as it may have functional whitespace
664
+ sub_text = token.replace('▁', ' ')
665
+ else:
666
+ raise ValueError(f'type {type(tokenizer)} not supported')
667
+ sub_texts.append(sub_text)
668
+
669
+ text = ''
670
+ spans = []
671
+ for sub_text in sub_texts:
672
+ span = (len(text), len(text) + len(sub_text)) # [start index, end index).
673
+ text += sub_text
674
+ spans.append(span)
675
+
676
+ # Text format:
677
+ # 1. T5Tokenizer/T5TokenizerFast:
678
+ # "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
679
+ # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
680
+ # 2. BartTokenizer (need to double check):
681
+ # "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
682
+ # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
683
+ return text, spans
684
+
685
+ def parse_od_from_text_and_spans(
686
+ self,
687
+ text,
688
+ pattern,
689
+ image_size,
690
+ phrase_centric=False
691
+ ):
692
+ parsed = list(re.finditer(pattern, text))
693
+
694
+ instances = []
695
+ for i in range(len(parsed)):
696
+ # Prepare instance.
697
+ instance = {}
698
+
699
+ if phrase_centric:
700
+ bbox_bins = [int(parsed[i].group(j)) for j in range(2, 6)]
701
+ else:
702
+ bbox_bins = [int(parsed[i].group(j)) for j in range(1, 5)]
703
+ instance['bbox'] = self.box_quantizer.dequantize(
704
+ boxes=torch.tensor(bbox_bins),
705
+ size=image_size
706
+ ).tolist()
707
+
708
+ if phrase_centric:
709
+ instance['cat_name'] = parsed[i].group(1).lower().strip()
710
+ else:
711
+ instance['cat_name'] = parsed[i].group(5).lower().strip()
712
+ instances.append(instance)
713
+
714
+ return instances
715
+
716
+ def parse_ocr_from_text_and_spans(self,
717
+ text,
718
+ pattern,
719
+ image_size,
720
+ area_threshold=-1.0,
721
+ ):
722
+ bboxes = []
723
+ labels = []
724
+ text = text.replace('<s>', '')
725
+ # ocr with regions
726
+ parsed = re.findall(pattern, text)
727
+ instances = []
728
+ image_width, image_height = image_size
729
+
730
+ for ocr_line in parsed:
731
+ ocr_content = ocr_line[0]
732
+ quad_box = ocr_line[1:]
733
+ quad_box = [int(i) for i in quad_box]
734
+ quad_box = self.coordinates_quantizer.dequantize(
735
+ torch.tensor(np.array(quad_box).reshape(-1, 2)),
736
+ size=image_size
737
+ ).reshape(-1).tolist()
738
+
739
+ if area_threshold > 0:
740
+ x_coords = [i for i in quad_box[0::2]]
741
+ y_coords = [i for i in quad_box[1::2]]
742
+
743
+ # apply the Shoelace formula
744
+ area = 0.5 * abs(sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)))
745
+
746
+ if area < (image_width * image_height) * area_threshold:
747
+ continue
748
+
749
+ bboxes.append(quad_box)
750
+ labels.append(ocr_content)
751
+ instances.append({
752
+ 'quad_box': quad_box,
753
+ 'text': ocr_content,
754
+ })
755
+ return instances
756
+
757
+ def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
758
+ # ignore <s> </s> and <pad>
759
+ cur_span = 0
760
+ if text.startswith('<s>'):
761
+ cur_span += 3
762
+
763
+ text = text.replace('<s>', '')
764
+ text = text.replace('</s>', '')
765
+ text = text.replace('<pad>', '')
766
+
767
+ pattern = r"([^<]+(?:<loc_\d+>){4,})"
768
+ phrases = re.findall(pattern, text)
769
+
770
+ # pattern should be text pattern and od pattern
771
+ pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
772
+ box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
773
+
774
+ instances = []
775
+ for pharse_text in phrases:
776
+ phrase_text_strip = pharse_text.replace('<ground>', '', 1)
777
+ phrase_text_strip = pharse_text.replace('<obj>', '', 1)
778
+
779
+ if phrase_text_strip == '':
780
+ cur_span += len(pharse_text)
781
+ continue
782
+
783
+ # Prepare instance.
784
+ instance = {}
785
+
786
+ # parse phrase, get string
787
+ phrase = re.search(pattern, phrase_text_strip)
788
+ if phrase is None:
789
+ cur_span += len(pharse_text)
790
+ continue
791
+
792
+ # parse bboxes by box_pattern
793
+ bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
794
+ if len(bboxes_parsed) == 0:
795
+ cur_span += len(pharse_text)
796
+ continue
797
+
798
+ phrase = phrase.group()
799
+ # remove leading and trailing spaces
800
+ phrase = phrase.strip()
801
+
802
+ if phrase in self.black_list_of_phrase_grounding:
803
+ cur_span += len(pharse_text)
804
+ continue
805
+
806
+ # a list of list
807
+ bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
808
+ instance['bbox'] = self.box_quantizer.dequantize(
809
+ boxes=torch.tensor(bbox_bins),
810
+ size=image_size
811
+ ).tolist()
812
+
813
+ # exclude non-ascii characters
814
+ phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
815
+ instance['cat_name'] = phrase
816
+
817
+ instances.append(instance)
818
+
819
+ return instances
820
+
821
+ def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
822
+ # temporary parse solution, split by '.'
823
+ # ignore <s> </s> and <pad>
824
+
825
+ text = text.replace('<s>', '')
826
+ text = text.replace('</s>', '')
827
+ text = text.replace('<pad>', '')
828
+
829
+ if allow_empty_phrase:
830
+ pattern = rf"(?:(?:<loc_\d+>){{4,}})"
831
+ else:
832
+ pattern = r"([^<]+(?:<loc_\d+>){4,})"
833
+ phrases = re.findall(pattern, text)
834
+
835
+ # pattern should be text pattern and od pattern
836
+ pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
837
+ box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
838
+
839
+ instances = []
840
+ for pharse_text in phrases:
841
+ phrase_text_strip = pharse_text.replace('<ground>', '', 1)
842
+ phrase_text_strip = pharse_text.replace('<obj>', '', 1)
843
+
844
+ if phrase_text_strip == '' and not allow_empty_phrase:
845
+ continue
846
+
847
+ # parse phrase, get string
848
+ phrase = re.search(pattern, phrase_text_strip)
849
+ if phrase is None:
850
+ continue
851
+
852
+ phrase = phrase.group()
853
+ # remove leading and trailing spaces
854
+ phrase = phrase.strip()
855
+
856
+ # parse bboxes by box_pattern
857
+ bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
858
+ if len(bboxes_parsed) == 0:
859
+ continue
860
+
861
+ # a list of list
862
+ bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
863
+
864
+ bboxes = self.box_quantizer.dequantize(
865
+ boxes=torch.tensor(bbox_bins),
866
+ size=image_size
867
+ ).tolist()
868
+
869
+ phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
870
+ for _bboxes in bboxes:
871
+ # Prepare instance.
872
+ instance = {}
873
+ instance['bbox'] = _bboxes
874
+ # exclude non-ascii characters
875
+ instance['cat_name'] = phrase
876
+ instances.append(instance)
877
+
878
+ return instances
879
+
880
+ def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
881
+ allow_empty_phrase=False,
882
+ polygon_sep_token='<sep>',
883
+ polygon_start_token='<poly>',
884
+ polygon_end_token='</poly>',
885
+ with_box_at_start=False,
886
+ ):
887
+
888
+ # ref_seg format: '<expression><x1><y1><x2><y2><><><sep><><><><>'
889
+ # ignore <s> </s> and <pad>
890
+
891
+ text = text.replace('<s>', '')
892
+ text = text.replace('</s>', '')
893
+ text = text.replace('<pad>', '')
894
+
895
+ if allow_empty_phrase:
896
+ pattern = rf"(?:(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
897
+ else:
898
+ # [^<]+: This part matches one or more characters that are not the < symbol.
899
+ # The ^ inside the square brackets [] is a negation, meaning it matches anything except <.
900
+ #
901
+ pattern = rf"([^<]+(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
902
+ phrases = re.findall(pattern, text)
903
+
904
+ phrase_string_pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_|<poly>)'
905
+ box_pattern = rf'((?:<loc_\d+>)+)(?:{re.escape(polygon_sep_token)}|$)'
906
+
907
+ # one polygons instance is separated by polygon_start_token and polygon_end_token
908
+ polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}'
909
+
910
+ instances = []
911
+ for phrase_text in phrases:
912
+
913
+ # exclude loc_\d+>
914
+ # need to get span if want to include category score
915
+ phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)
916
+
917
+ # phrase = phrase.replace('<poly>', '')
918
+ # phrase = phrase.replace('poly>', '')
919
+
920
+ if phrase_text_strip == '' and not allow_empty_phrase:
921
+ continue
922
+
923
+
924
+ # parse phrase, get string
925
+ phrase = re.search(phrase_string_pattern, phrase_text_strip)
926
+ if phrase is None:
927
+ continue
928
+ phrase = phrase.group()
929
+ # remove leading and trailing spaces
930
+ phrase = phrase.strip()
931
+
932
+ # parse bboxes by box_pattern
933
+
934
+ # split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
935
+ if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
936
+ polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
937
+ else:
938
+ polygons_instances_parsed = [phrase_text]
939
+
940
+ for _polygons_instances_parsed in polygons_instances_parsed:
941
+ # Prepare instance.
942
+ instance = {}
943
+
944
+ # polygons_parsed= list(re.finditer(box_pattern, phrase_text))
945
+ if isinstance(_polygons_instances_parsed, str):
946
+ polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed))
947
+ else:
948
+ polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1)))
949
+ if len(polygons_parsed) == 0:
950
+ continue
951
+
952
+ # a list of list (polygon)
953
+ bbox = []
954
+ polygons = []
955
+ for _polygon_parsed in polygons_parsed:
956
+ # group 1: whole <loc_\d+>...</loc_\d+>
957
+ _polygon = _polygon_parsed.group(1)
958
+ # parse into list of int
959
+ _polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'<loc_(\d+)>', _polygon)]
960
+ if with_box_at_start and len(bbox) == 0:
961
+ if len(_polygon) > 4:
962
+ # no valid bbox prediction
963
+ bbox = _polygon[:4]
964
+ _polygon = _polygon[4:]
965
+ else:
966
+ bbox = [0, 0, 0, 0]
967
+ # abandon last element if is not paired
968
+ if len(_polygon) % 2 == 1:
969
+ _polygon = _polygon[:-1]
970
+
971
+ # reshape into (n, 2)
972
+ _polygon = self.coordinates_quantizer.dequantize(
973
+ torch.tensor(np.array(_polygon).reshape(-1, 2)),
974
+ size=image_size
975
+ ).reshape(-1).tolist()
976
+ # reshape back
977
+ polygons.append(_polygon)
978
+
979
+ instance['cat_name'] = phrase
980
+ instance['polygons'] = polygons
981
+ if len(bbox) != 0:
982
+ instance['bbox'] = self.box_quantizer.dequantize(
983
+ boxes=torch.tensor([bbox]),
984
+ size=image_size
985
+ ).tolist()[0]
986
+
987
+ instances.append(instance)
988
+
989
+ return instances
990
+
991
+ def __call__(
992
+ self,
993
+ text=None,
994
+ image_size=None,
995
+ parse_tasks=None,
996
+ ):
997
+ """
998
+ Args:
999
+ text: model outputs
1000
+ image_size: (width, height)
1001
+ parse_tasks: a list of tasks to parse, if None, parse all tasks.
1002
+
1003
+ """
1004
+ if parse_tasks is not None:
1005
+ if isinstance(parse_tasks, str):
1006
+ parse_tasks = [parse_tasks]
1007
+ for _parse_task in parse_tasks:
1008
+ assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1009
+
1010
+ # sequence or text should be provided
1011
+ assert text is not None, 'text should be provided'
1012
+
1013
+ parsed_dict = {
1014
+ 'text': text
1015
+ }
1016
+
1017
+ for task in self.parse_tasks:
1018
+ if parse_tasks is not None and task not in parse_tasks:
1019
+ continue
1020
+
1021
+ pattern = self.parse_tasks_configs[task].get('PATTERN', None)
1022
+
1023
+ if task == 'ocr':
1024
+ instances = self.parse_ocr_from_text_and_spans(
1025
+ text,
1026
+ pattern=pattern,
1027
+ image_size=image_size,
1028
+ area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.01),
1029
+ )
1030
+ parsed_dict['ocr'] = instances
1031
+ elif task == 'phrase_grounding':
1032
+ instances = self.parse_phrase_grounding_from_text_and_spans(
1033
+ text,
1034
+ pattern=pattern,
1035
+ image_size=image_size,
1036
+ )
1037
+ parsed_dict['phrase_grounding'] = instances
1038
+ elif task == 'pure_text':
1039
+ parsed_dict['pure_text'] = text
1040
+ elif task == 'description_with_bboxes':
1041
+ instances = self.parse_description_with_bboxes_from_text_and_spans(
1042
+ text,
1043
+ pattern=pattern,
1044
+ image_size=image_size,
1045
+ )
1046
+ parsed_dict['description_with_bboxes'] = instances
1047
+ elif task == 'description_with_polygons':
1048
+ instances = self.parse_description_with_polygons_from_text_and_spans(
1049
+ text,
1050
+ pattern=pattern,
1051
+ image_size=image_size,
1052
+ )
1053
+ parsed_dict['description_with_polygons'] = instances
1054
+ elif task == 'polygons':
1055
+ instances = self.parse_description_with_polygons_from_text_and_spans(
1056
+ text,
1057
+ pattern=pattern,
1058
+ image_size=image_size,
1059
+ allow_empty_phrase=True,
1060
+ )
1061
+ parsed_dict['polygons'] = instances
1062
+ elif task == 'bboxes':
1063
+ instances = self.parse_description_with_bboxes_from_text_and_spans(
1064
+ text,
1065
+ pattern=pattern,
1066
+ image_size=image_size,
1067
+ allow_empty_phrase=True,
1068
+ )
1069
+ parsed_dict['bboxes'] = instances
1070
+ elif task == 'description_with_bboxes_or_polygons':
1071
+ if '<poly>' in text:
1072
+ # only support either polygons or bboxes, not both at the same time
1073
+ instances = self.parse_description_with_polygons_from_text_and_spans(
1074
+ text,
1075
+ pattern=pattern,
1076
+ image_size=image_size,
1077
+ )
1078
+ else:
1079
+ instances = self.parse_description_with_bboxes_from_text_and_spans(
1080
+ text,
1081
+ pattern=pattern,
1082
+ image_size=image_size,
1083
+ )
1084
+ parsed_dict['description_with_bboxes_or_polygons'] = instances
1085
+ else:
1086
+ raise ValueError("task {} is not supported".format(task))
1087
+
1088
+ return parsed_dict