Files changed (1) hide show
  1. modeling_florence2.py +15 -7
modeling_florence2.py CHANGED
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2643
  return x
2644
 
2645
  def _merge_input_ids_with_image_features(
2646
- self, image_features, inputs_embeds
2647
  ):
2648
  batch_size, image_token_length = image_features.size()[:-1]
2649
  device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2655
  return image_features, image_attention_mask
2656
 
2657
  task_prefix_embeds = inputs_embeds
2658
- task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
 
2659
 
2660
- if len(task_prefix_attention_mask.shape) == 3:
2661
- task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2662
 
2663
  # concat [image embeds, task prefix embeds]
2664
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2719,12 +2721,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2719
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
2720
  "A green car parked in front of a yellow building."
2721
  ```"""
 
 
2722
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2723
  output_hidden_states = (
2724
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2725
  )
2726
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2727
-
2728
  image_features = None
2729
  if inputs_embeds is None:
2730
  # 1. Extra the input embeddings
@@ -2735,7 +2739,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2735
  # (batch_size, num_image_tokens, hidden_size)
2736
  image_features = self._encode_image(pixel_values)
2737
  inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
2738
-
 
 
2739
  if inputs_embeds is not None:
2740
  attention_mask = attention_mask.to(inputs_embeds.dtype)
2741
  outputs = self.language_model(
@@ -2781,6 +2787,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2781
  input_ids,
2782
  inputs_embeds=None,
2783
  pixel_values=None,
 
2784
  **kwargs
2785
  ):
2786
 
@@ -2791,11 +2798,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2791
  # 2. Merge text and images
2792
  if pixel_values is not None:
2793
  image_features = self._encode_image(pixel_values)
2794
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
2795
 
2796
  return self.language_model.generate(
2797
  input_ids=None,
2798
  inputs_embeds=inputs_embeds,
 
2799
  **kwargs
2800
  )
2801
 
 
2643
  return x
2644
 
2645
  def _merge_input_ids_with_image_features(
2646
+ self, image_features, inputs_embeds, task_prefix_attention_mask=None
2647
  ):
2648
  batch_size, image_token_length = image_features.size()[:-1]
2649
  device = image_features.device
 
2655
  return image_features, image_attention_mask
2656
 
2657
  task_prefix_embeds = inputs_embeds
2658
+
2659
+ if task_prefix_attention_mask is None:
2660
+ task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
2661
 
2662
+ if len(task_prefix_attention_mask.shape) == 3:
2663
+ task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2664
 
2665
  # concat [image embeds, task prefix embeds]
2666
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
 
2721
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
2722
  "A green car parked in front of a yellow building."
2723
  ```"""
2724
+ print("asdasdasdasdasdasdasdasda")
2725
+
2726
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2727
  output_hidden_states = (
2728
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2729
  )
2730
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2731
+ print("asdasdasdasdasdasdasdasda")
2732
  image_features = None
2733
  if inputs_embeds is None:
2734
  # 1. Extra the input embeddings
 
2739
  # (batch_size, num_image_tokens, hidden_size)
2740
  image_features = self._encode_image(pixel_values)
2741
  inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
2742
+
2743
+ print(attention_mask)
2744
+
2745
  if inputs_embeds is not None:
2746
  attention_mask = attention_mask.to(inputs_embeds.dtype)
2747
  outputs = self.language_model(
 
2787
  input_ids,
2788
  inputs_embeds=None,
2789
  pixel_values=None,
2790
+ attention_mask=None,
2791
  **kwargs
2792
  ):
2793
 
 
2798
  # 2. Merge text and images
2799
  if pixel_values is not None:
2800
  image_features = self._encode_image(pixel_values)
2801
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
2802
 
2803
  return self.language_model.generate(
2804
  input_ids=None,
2805
  inputs_embeds=inputs_embeds,
2806
+ attention_mask=attention_mask,
2807
  **kwargs
2808
  )
2809