pawlowskipawel commited on
Commit
c69c410
·
verified ·
1 Parent(s): 39ddb41

Add task_prefix_attention_mask argument to _merge_input_ids_with_image_features for better padding handling

Browse files

This PR introduces a small change in the _merge_input_ids_with_image_features function by adding a task_prefix_attention_mask=None argument. This enhancement ensures that when doing batch processing with padding to the max length, the attention mask correctly ignores padding tokens.

Changes Made:
1. Added task_prefix_attention_mask=None argument to _merge_input_ids_with_image_features function.
2. Updated the function to incorporate the provided attention mask, allowing it to ignore padding tokens during batch processing.

Below is an example demonstrating the issue and the improvement:
```python
prompts =["prompt", "longer prompt", "much much longer prompt"]

url = "https://huggingface.co./datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"

image = Image.open(requests.get(url, stream=True).raw)
images = [image] * len(prompts)

inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to("cuda", torch.float16)

inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
image_features = model._encode_image(inputs.pixel_values)

print(inputs.input_ids)
# Output:
# tensor([[ 0, 12501, 3320, 2, 1, 1],
# [ 0, 3479, 254, 14302, 2, 1],
# [ 0, 28431, 203, 1181, 14302, 2]], device='cuda:0')

# Before change
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')

# After change
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=inputs.attention_mask)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
```

Files changed (1) hide show
  1. modeling_florence2.py +10 -6
modeling_florence2.py CHANGED
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2643
  return x
2644
 
2645
  def _merge_input_ids_with_image_features(
2646
- self, image_features, inputs_embeds
2647
  ):
2648
  batch_size, image_token_length = image_features.size()[:-1]
2649
  device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2655
  return image_features, image_attention_mask
2656
 
2657
  task_prefix_embeds = inputs_embeds
2658
- task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
2659
 
2660
- if len(task_prefix_attention_mask.shape) == 3:
2661
- task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
 
 
2662
 
2663
  # concat [image embeds, task prefix embeds]
2664
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2734,7 +2736,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2734
  if pixel_values is not None:
2735
  # (batch_size, num_image_tokens, hidden_size)
2736
  image_features = self._encode_image(pixel_values)
2737
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
2738
 
2739
  if inputs_embeds is not None:
2740
  attention_mask = attention_mask.to(inputs_embeds.dtype)
@@ -2781,6 +2783,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2781
  input_ids,
2782
  inputs_embeds=None,
2783
  pixel_values=None,
 
2784
  **kwargs
2785
  ):
2786
 
@@ -2791,11 +2794,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2791
  # 2. Merge text and images
2792
  if pixel_values is not None:
2793
  image_features = self._encode_image(pixel_values)
2794
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
2795
 
2796
  return self.language_model.generate(
2797
  input_ids=None,
2798
  inputs_embeds=inputs_embeds,
 
2799
  **kwargs
2800
  )
2801
 
 
2643
  return x
2644
 
2645
  def _merge_input_ids_with_image_features(
2646
+ self, image_features, inputs_embeds, task_prefix_attention_mask=None
2647
  ):
2648
  batch_size, image_token_length = image_features.size()[:-1]
2649
  device = image_features.device
 
2655
  return image_features, image_attention_mask
2656
 
2657
  task_prefix_embeds = inputs_embeds
 
2658
 
2659
+ if task_prefix_attention_mask is None:
2660
+ task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
2661
+
2662
+ if len(task_prefix_attention_mask.shape) == 3:
2663
+ task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2664
 
2665
  # concat [image embeds, task prefix embeds]
2666
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
 
2736
  if pixel_values is not None:
2737
  # (batch_size, num_image_tokens, hidden_size)
2738
  image_features = self._encode_image(pixel_values)
2739
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
2740
 
2741
  if inputs_embeds is not None:
2742
  attention_mask = attention_mask.to(inputs_embeds.dtype)
 
2783
  input_ids,
2784
  inputs_embeds=None,
2785
  pixel_values=None,
2786
+ attention_mask=None,
2787
  **kwargs
2788
  ):
2789
 
 
2794
  # 2. Merge text and images
2795
  if pixel_values is not None:
2796
  image_features = self._encode_image(pixel_values)
2797
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
2798
 
2799
  return self.language_model.generate(
2800
  input_ids=None,
2801
  inputs_embeds=inputs_embeds,
2802
+ attention_mask=attention_mask,
2803
  **kwargs
2804
  )
2805