wilbown committed
Commit 5a30066
1 Parent(s): 8b0fa0e

Fix to handle None for image

Files changed (1)
  1. modeling_llava.py +3 -3
modeling_llava.py CHANGED
@@ -1661,7 +1661,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
         else:
             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
             # generation with cache
-            if past_key_values is not None and image_features is not None and input_ids.shape[1] == 1:
+            if past_key_values is not None and input_ids.shape[1] == 1:
                 # Retrieve the first layer to inspect the logits and mask out the hidden states
                 # that are set to 0
                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
@@ -1734,6 +1734,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None
+            past_length -= image_features.shape[1]-1 if image_features is not None else 0
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1746,8 +1747,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]+image_features.shape[1]-1:
-                past_length -= image_features.shape[1]-1
+            elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
                 attention_mask = attention_mask[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
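For context, a minimal sketch of the trimming arithmetic the second and third hunks change (the standalone function, its parameters, and the test values are illustrative assumptions, not part of the commit; only the offset logic mirrors the diff). Before the fix, this path always read image_features.shape[1], so text-only cached generation, where image_features is None, raised an AttributeError; the patch folds the image-token offset into past_length once, and only when image features are present. The first hunk makes the matching change in forward, entering the cache branch even when no image was passed.

# Hypothetical reduction of the patched trimming logic; only the
# past_length arithmetic is taken from the diff.
def unprocessed_token_count(input_ids_len, past_length, image_seq_len=None):
    # Patched behavior (new line 1737): remove the image-token offset
    # from past_length, but only when image features exist. The old
    # code (old lines 1749-1750) did this unguarded and crashed when
    # image_features was None.
    past_length -= image_seq_len - 1 if image_seq_len is not None else 0
    # 2 - past_length smaller than input_ids: keep the unprocessed tail.
    if past_length < input_ids_len:
        return input_ids_len - past_length
    # 3 - otherwise assume input_ids only has unprocessed tokens.
    return input_ids_len

# Text-only cached generation no longer touches image features:
assert unprocessed_token_count(input_ids_len=1, past_length=12) == 1
# With a 576-patch image in the cache, 575 cached positions belong to
# image tokens that never appeared in input_ids:
assert unprocessed_token_count(8, past_length=580, image_seq_len=576) == 3

Note the operator precedence on the new line 1737: the conditional expression binds the whole right-hand side, so past_length -= x - 1 if cond else 0 subtracts either x - 1 or 0, as intended.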