Fix to handle None for image features (guard `image_features is None` in the cached-generation path)
Browse files- modeling_llava.py +3 -3
modeling_llava.py
CHANGED
@@ -1661,7 +1661,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
|
|
1661 |
else:
|
1662 |
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
1663 |
# generation with cache
|
1664 |
-
if past_key_values is not None and
|
1665 |
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
1666 |
# that are set to 0
|
1667 |
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
@@ -1734,6 +1734,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
|
|
1734 |
else:
|
1735 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
1736 |
max_cache_length = None
|
|
|
1737 |
|
1738 |
# Keep only the unprocessed tokens:
|
1739 |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
@@ -1746,8 +1747,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
|
|
1746 |
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1747 |
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1748 |
# input_ids based on the past_length.
|
1749 |
-
elif past_length < input_ids.shape[1]:
|
1750 |
-
past_length -= image_features.shape[1]-1
|
1751 |
input_ids = input_ids[:, past_length:]
|
1752 |
attention_mask = attention_mask[:, past_length:]
|
1753 |
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
|
|
1661 |
else:
|
1662 |
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
1663 |
# generation with cache
|
1664 |
+
if past_key_values is not None and input_ids.shape[1] == 1:
|
1665 |
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
1666 |
# that are set to 0
|
1667 |
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
|
|
1734 |
else:
|
1735 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
1736 |
max_cache_length = None
|
1737 |
+
past_length -= image_features.shape[1]-1 if image_features is not None else 0
|
1738 |
|
1739 |
# Keep only the unprocessed tokens:
|
1740 |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
|
|
1747 |
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1748 |
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1749 |
# input_ids based on the past_length.
|
1750 |
+
elif past_length < input_ids.shape[1]:
|
|
|
1751 |
input_ids = input_ids[:, past_length:]
|
1752 |
attention_mask = attention_mask[:, past_length:]
|
1753 |
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|