clean up #16
opened by pawlowskipawel

modeling_florence2.py  CHANGED  (+15 −7)
```diff
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return x
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds
+        self, image_features, inputs_embeds, task_prefix_attention_mask=None
     ):
         batch_size, image_token_length = image_features.size()[:-1]
         device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return image_features, image_attention_mask
 
         task_prefix_embeds = inputs_embeds
-
+
+        if task_prefix_attention_mask is None:
+            task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
-
-
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
         # concat [image embeds, task prefix embeds]
         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2719,12 +2721,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "A green car parked in front of a yellow building."
         ```"""
+        print("asdasdasdasdasdasdasdasda")
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+        print("asdasdasdasdasdasdasdasda")
         image_features = None
         if inputs_embeds is None:
             # 1. Extra the input embeddings
@@ -2735,7 +2739,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
             # (batch_size, num_image_tokens, hidden_size)
             image_features = self._encode_image(pixel_values)
             inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
-
+
+        print(attention_mask)
+
         if inputs_embeds is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
         outputs = self.language_model(
@@ -2781,6 +2787,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         input_ids,
         inputs_embeds=None,
         pixel_values=None,
+        attention_mask=None,
         **kwargs
     ):
 
@@ -2791,11 +2798,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         # 2. Merge text and images
         if pixel_values is not None:
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         return self.language_model.generate(
             input_ids=None,
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             **kwargs
         )
 
```
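For reviewers, here is a minimal, self-contained sketch of what the new `task_prefix_attention_mask` plumbing does: concatenate image-token embeddings with the task-prefix (text) embeddings and build the matching attention mask, defaulting the text mask to all ones when none is passed and collapsing a 3D mask to 2D. The function name `merge_image_and_text`, the tensor shapes, the all-ones image mask, and the final mask concatenation (which the hunks above do not show) are assumptions, not the actual Florence-2 implementation.

```python
import torch

# Hypothetical stand-in for Florence2's _merge_input_ids_with_image_features,
# reduced to the logic this PR touches.
def merge_image_and_text(image_features, inputs_embeds, task_prefix_attention_mask=None):
    # image_features: (batch, num_image_tokens, hidden)
    # inputs_embeds:  (batch, text_len, hidden) -- the task-prefix embeddings
    batch_size, image_token_length = image_features.size()[:-1]
    device = image_features.device

    # Assumption: image tokens are always attended to.
    image_attention_mask = torch.ones(batch_size, image_token_length, device=device)

    task_prefix_embeds = inputs_embeds
    # New in this PR: default to an all-ones text mask when the caller passes none.
    if task_prefix_attention_mask is None:
        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
    # New in this PR: collapse a 3D mask, e.g. (batch, 1, text_len), to 2D.
    if len(task_prefix_attention_mask.shape) == 3:
        task_prefix_attention_mask = task_prefix_attention_mask[:, 0]

    # concat [image embeds, task prefix embeds] along the sequence dimension;
    # the mask concat mirrors it (assumption: not visible in the diff hunks).
    inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
    attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
    return inputs_embeds, attention_mask

# Toy check of the shapes generate() would receive:
img = torch.randn(2, 4, 8)  # 2 examples, 4 image tokens, hidden size 8
txt = torch.randn(2, 3, 8)  # 3 task-prefix tokens
emb, mask = merge_image_and_text(img, txt, task_prefix_attention_mask=torch.ones(2, 1, 3))
print(emb.shape, mask.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 7])
```

The combined `inputs_embeds`/`attention_mask` pair is what the changed `generate` path now forwards to `self.language_model.generate`, so padded task prefixes are no longer attended to as if they were real tokens.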