haipingwu commited on
Commit
f0acedb
·
verified ·
1 Parent(s): f928440

add_confidence_score (#56)

Browse files

- add confidence score parsing (2ce7cd7837dfc8d93bf1d77dae95669ef1bcf0b3)

Files changed (2) hide show
  1. README.md +38 -0
  2. processing_florence2.py +83 -24
README.md CHANGED
@@ -190,6 +190,44 @@ prompt = "<OCR_WITH_REGION>"
190
  run_example(prompt)
191
  ```
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  for More detailed examples, please refer to [notebook](https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb)
194
  </details>
195
 
 
190
  run_example(prompt)
191
  ```
192
 
193
+ ### Output confidence score with Object Detection
194
+ ```python
195
+
196
+ def run_example_with_score(task_prompt, text_input=None):
197
+ if text_input is None:
198
+ prompt = task_prompt
199
+ else:
200
+ prompt = task_prompt + text_input
201
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
202
+ generated_ids = model.generate(
203
+ input_ids=inputs["input_ids"],
204
+ pixel_values=inputs["pixel_values"],
205
+ max_new_tokens=1024,
206
+ num_beams=3,
207
+ return_dict_in_generate=True,
208
+ output_scores=True,
209
+ )
210
+ generated_text = processor.batch_decode(generated_ids.sequences, skip_special_tokens=False)[0]
211
+
212
+ prediction, scores, beam_indices = generated_ids.sequences, generated_ids.scores, generated_ids.beam_indices
213
+ transition_beam_scores = model.compute_transition_scores(
214
+ sequences=prediction,
215
+ scores=scores,
216
+ beam_indices=beam_indices,
217
+ )
218
+
219
+ parsed_answer = processor.post_process_generation(sequence=generated_ids.sequences[0],
220
+ transition_beam_score=transition_beam_scores[0],
221
+ task=task_prompt, image_size=(image.width, image.height)
222
+ )
223
+
224
+ print(parsed_answer)
225
+
226
+ prompt = "<OD>"
227
+ run_example_with_score(prompt)
228
+
229
+ ```
230
+
231
  for More detailed examples, please refer to [notebook](https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb)
232
  </details>
233
 
processing_florence2.py CHANGED
@@ -20,6 +20,7 @@ import re
20
  import logging
21
  from typing import List, Optional, Union
22
  import numpy as np
 
23
 
24
  import torch
25
 
@@ -32,6 +33,7 @@ from transformers.tokenization_utils_base import (
32
  TextInput,
33
  TruncationStrategy,
34
  )
 
35
  from transformers.utils import TensorType
36
 
37
 
@@ -304,7 +306,7 @@ class Florence2Processor(ProcessorMixin):
304
  image_processor_input_names = self.image_processor.model_input_names
305
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
306
 
307
- def post_process_generation(self, text, task, image_size):
308
  """
309
  Post-process the output of the model to each of the task outputs.
310
 
@@ -317,6 +319,8 @@ class Florence2Processor(ProcessorMixin):
317
  task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
318
  task_answer = self.post_processor(
319
  text=text,
 
 
320
  image_size=image_size,
321
  parse_tasks=task_answer_post_processing_type,
322
  )[task_answer_post_processing_type]
@@ -330,6 +334,9 @@ class Florence2Processor(ProcessorMixin):
330
  bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
331
  labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
332
  final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
 
 
 
333
  elif task_answer_post_processing_type in ['ocr']:
334
  bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
335
  labels = [str(_od_instance['text']) for _od_instance in task_answer]
@@ -591,7 +598,8 @@ class Florence2PostProcesser(object):
591
  'PARSE_TASKS': [
592
  {
593
  'TASK_NAME': 'od',
594
- 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
 
595
  },
596
  {
597
  'TASK_NAME': 'ocr',
@@ -607,6 +615,7 @@ class Florence2PostProcesser(object):
607
  },
608
  {
609
  'TASK_NAME': 'description_with_bboxes',
 
610
  },
611
  {
612
  'TASK_NAME': 'description_with_polygons',
@@ -647,10 +656,6 @@ class Florence2PostProcesser(object):
647
  filtered_tokens = tokenizer.convert_ids_to_tokens(
648
  token_ids, skip_special_tokens=False)
649
  assert len(filtered_tokens) == len(token_ids)
650
-
651
- # To avoid mixing byte-level and unicode for byte-level BPT
652
- # we need to build string separately for added tokens and byte-level tokens
653
- # cf. https://github.com/huggingface/transformers/issues/1133
654
  sub_texts = []
655
  for token in filtered_tokens:
656
  if token in self.all_special_tokens:
@@ -658,10 +663,6 @@ class Florence2PostProcesser(object):
658
  else:
659
  if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
660
  sub_text = tokenizer.convert_tokens_to_string([token])
661
- elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
662
- # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
663
- # Note: Do not strip sub_text as it may have functional whitespace
664
- sub_text = token.replace('▁', ' ')
665
  else:
666
  raise ValueError(f'type {type(tokenizer)} not supported')
667
  sub_texts.append(sub_text)
@@ -672,14 +673,6 @@ class Florence2PostProcesser(object):
672
  span = (len(text), len(text) + len(sub_text)) # [start index, end index).
673
  text += sub_text
674
  spans.append(span)
675
-
676
- # Text format:
677
- # 1. T5Tokenizer/T5TokenizerFast:
678
- # "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
679
- # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
680
- # 2. BartTokenizer (need to double check):
681
- # "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
682
- # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
683
  return text, spans
684
 
685
  def parse_od_from_text_and_spans(
@@ -714,7 +707,7 @@ class Florence2PostProcesser(object):
714
  return instances
715
 
716
  def parse_ocr_from_text_and_spans(self,
717
- text,
718
  pattern,
719
  image_size,
720
  area_threshold=-1.0,
@@ -818,9 +811,26 @@ class Florence2PostProcesser(object):
818
 
819
  return instances
820
 
821
- def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
822
- # temporary parse solution, split by '.'
823
- # ignore <s> </s> and <pad>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
 
825
  text = text.replace('<s>', '')
826
  text = text.replace('</s>', '')
@@ -842,13 +852,16 @@ class Florence2PostProcesser(object):
842
  phrase_text_strip = pharse_text.replace('<obj>', '', 1)
843
 
844
  if phrase_text_strip == '' and not allow_empty_phrase:
 
845
  continue
846
 
847
  # parse phrase, get string
848
  phrase = re.search(pattern, phrase_text_strip)
849
  if phrase is None:
 
850
  continue
851
 
 
852
  phrase = phrase.group()
853
  # remove leading and trailing spaces
854
  phrase = phrase.strip()
@@ -856,6 +869,7 @@ class Florence2PostProcesser(object):
856
  # parse bboxes by box_pattern
857
  bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
858
  if len(bboxes_parsed) == 0:
 
859
  continue
860
 
861
  # a list of list
@@ -866,14 +880,42 @@ class Florence2PostProcesser(object):
866
  size=image_size
867
  ).tolist()
868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
  phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
870
- for _bboxes in bboxes:
871
  # Prepare instance.
872
  instance = {}
873
  instance['bbox'] = _bboxes
874
  # exclude non-ascii characters
875
  instance['cat_name'] = phrase
 
 
876
  instances.append(instance)
 
 
877
 
878
  return instances
879
 
@@ -991,6 +1033,8 @@ class Florence2PostProcesser(object):
991
  def __call__(
992
  self,
993
  text=None,
 
 
994
  image_size=None,
995
  parse_tasks=None,
996
  ):
@@ -1008,7 +1052,18 @@ class Florence2PostProcesser(object):
1008
  assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1009
 
1010
  # sequence or text should be provided
1011
- assert text is not None, 'text should be provided'
 
 
 
 
 
 
 
 
 
 
 
1012
 
1013
  parsed_dict = {
1014
  'text': text
@@ -1019,6 +1074,7 @@ class Florence2PostProcesser(object):
1019
  continue
1020
 
1021
  pattern = self.parse_tasks_configs[task].get('PATTERN', None)
 
1022
 
1023
  if task == 'ocr':
1024
  instances = self.parse_ocr_from_text_and_spans(
@@ -1040,6 +1096,9 @@ class Florence2PostProcesser(object):
1040
  elif task == 'description_with_bboxes':
1041
  instances = self.parse_description_with_bboxes_from_text_and_spans(
1042
  text,
 
 
 
1043
  pattern=pattern,
1044
  image_size=image_size,
1045
  )
 
20
  import logging
21
  from typing import List, Optional, Union
22
  import numpy as np
23
+ import math
24
 
25
  import torch
26
 
 
33
  TextInput,
34
  TruncationStrategy,
35
  )
36
+ from transformers import BartTokenizer, BartTokenizerFast
37
  from transformers.utils import TensorType
38
 
39
 
 
306
  image_processor_input_names = self.image_processor.model_input_names
307
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
308
 
309
+ def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
310
  """
311
  Post-process the output of the model to each of the task outputs.
312
 
 
319
  task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
320
  task_answer = self.post_processor(
321
  text=text,
322
+ sequence=sequence,
323
+ transition_beam_score=transition_beam_score,
324
  image_size=image_size,
325
  parse_tasks=task_answer_post_processing_type,
326
  )[task_answer_post_processing_type]
 
334
  bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
335
  labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
336
  final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
337
+ if len(od_instances) and 'score' in od_instances[0]:
338
+ scores_od = [_od_instance['score'] for _od_instance in od_instances]
339
+ final_answer['scores'] = scores_od
340
  elif task_answer_post_processing_type in ['ocr']:
341
  bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
342
  labels = [str(_od_instance['text']) for _od_instance in task_answer]
 
598
  'PARSE_TASKS': [
599
  {
600
  'TASK_NAME': 'od',
601
+ 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
602
+ 'SCORE_MODE': 'avg_loc_scores'
603
  },
604
  {
605
  'TASK_NAME': 'ocr',
 
615
  },
616
  {
617
  'TASK_NAME': 'description_with_bboxes',
618
+ 'SCORE_MODE': 'avg_loc_scores'
619
  },
620
  {
621
  'TASK_NAME': 'description_with_polygons',
 
656
  filtered_tokens = tokenizer.convert_ids_to_tokens(
657
  token_ids, skip_special_tokens=False)
658
  assert len(filtered_tokens) == len(token_ids)
 
 
 
 
659
  sub_texts = []
660
  for token in filtered_tokens:
661
  if token in self.all_special_tokens:
 
663
  else:
664
  if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
665
  sub_text = tokenizer.convert_tokens_to_string([token])
 
 
 
 
666
  else:
667
  raise ValueError(f'type {type(tokenizer)} not supported')
668
  sub_texts.append(sub_text)
 
673
  span = (len(text), len(text) + len(sub_text)) # [start index, end index).
674
  text += sub_text
675
  spans.append(span)
 
 
 
 
 
 
 
 
676
  return text, spans
677
 
678
  def parse_od_from_text_and_spans(
 
707
  return instances
708
 
709
  def parse_ocr_from_text_and_spans(self,
710
+ text,
711
  pattern,
712
  image_size,
713
  area_threshold=-1.0,
 
811
 
812
  return instances
813
 
814
+ def parse_description_with_bboxes_from_text_and_spans(
815
+ self,
816
+ text,
817
+ spans=None,
818
+ scores=None,
819
+ score_mode=None,
820
+ pattern=None,
821
+ image_size=None,
822
+ allow_empty_phrase=False
823
+ ):
824
+ def find_matched_token_indices(cur_span, token_spans):
825
+ inds = []
826
+ for i, token_span in enumerate(token_spans):
827
+ if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
828
+ inds.append(i)
829
+ return inds
830
+
831
+ cur_span = 0
832
+ if text.startswith('<s>'):
833
+ cur_span += 3
834
 
835
  text = text.replace('<s>', '')
836
  text = text.replace('</s>', '')
 
852
  phrase_text_strip = pharse_text.replace('<obj>', '', 1)
853
 
854
  if phrase_text_strip == '' and not allow_empty_phrase:
855
+ cur_span += len(pharse_text)
856
  continue
857
 
858
  # parse phrase, get string
859
  phrase = re.search(pattern, phrase_text_strip)
860
  if phrase is None:
861
+ cur_span += len(pharse_text)
862
  continue
863
 
864
+ phrase_span = phrase.span()
865
  phrase = phrase.group()
866
  # remove leading and trailing spaces
867
  phrase = phrase.strip()
 
869
  # parse bboxes by box_pattern
870
  bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
871
  if len(bboxes_parsed) == 0:
872
+ cur_span += len(pharse_text)
873
  continue
874
 
875
  # a list of list
 
880
  size=image_size
881
  ).tolist()
882
 
883
+ if score_mode == 'avg_loc_scores':
884
+ if spans is None or scores is None:
885
+ all_scores = None
886
+ else:
887
+ bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
888
+ all_scores = []
889
+ for _spans in bbox_end_spans:
890
+ token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
891
+ loc_scores = [scores[token_i] for token_i in token_inds]
892
+ score = sum(loc_scores) / len(loc_scores)
893
+ all_scores.append(score)
894
+ elif score_mode == 'avg_cat_name_scores':
895
+ if spans is None or scores is None:
896
+ all_scores = None
897
+ else:
898
+ cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
899
+ cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
900
+ score = sum(cat_name_scores) / len(cat_name_scores)
901
+ all_scores = [score] * len(bboxes)
902
+ elif score_mode is None:
903
+ all_scores = None
904
+ else:
905
+ raise ValueError('Unknown score mode: {}'.format(score_mode))
906
+
907
  phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
908
+ for _idx, _bboxes in enumerate(bboxes):
909
  # Prepare instance.
910
  instance = {}
911
  instance['bbox'] = _bboxes
912
  # exclude non-ascii characters
913
  instance['cat_name'] = phrase
914
+ if all_scores is not None:
915
+ instance['score'] = math.exp(all_scores[_idx])
916
  instances.append(instance)
917
+
918
+ cur_span += len(pharse_text)
919
 
920
  return instances
921
 
 
1033
  def __call__(
1034
  self,
1035
  text=None,
1036
+ sequence=None,
1037
+ transition_beam_score=None,
1038
  image_size=None,
1039
  parse_tasks=None,
1040
  ):
 
1052
  assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1053
 
1054
  # sequence or text should be provided
1055
+ assert sequence is not None or text is not None, 'sequence or text should be provided'
1056
+ assert sequence is None or text is None, 'only one of sequence and text should be provided'
1057
+
1058
+ if sequence is not None:
1059
+ sequence = sequence.tolist()[1:]
1060
+ text, spans = self.decode_with_spans(self.tokenizer, sequence)
1061
+ if transition_beam_score is not None:
1062
+ transition_beam_score = transition_beam_score.tolist()
1063
+ assert len(sequence) == len(transition_beam_score)
1064
+ else:
1065
+ spans = None
1066
+ transition_beam_score = None
1067
 
1068
  parsed_dict = {
1069
  'text': text
 
1074
  continue
1075
 
1076
  pattern = self.parse_tasks_configs[task].get('PATTERN', None)
1077
+ score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
1078
 
1079
  if task == 'ocr':
1080
  instances = self.parse_ocr_from_text_and_spans(
 
1096
  elif task == 'description_with_bboxes':
1097
  instances = self.parse_description_with_bboxes_from_text_and_spans(
1098
  text,
1099
+ spans=spans,
1100
+ scores=transition_beam_score,
1101
+ score_mode=score_mode,
1102
  pattern=pattern,
1103
  image_size=image_size,
1104
  )