visheratin committed on
Commit
e9369b2
·
verified ·
1 Parent(s): f53fea1

Update modeling file

Browse files
Files changed (1) hide show
  1. modeling_llava.py +75 -71
modeling_llava.py CHANGED
@@ -1148,71 +1148,6 @@ class PhiForCausalLM(PhiPreTrainedModel):
1148
  attentions=outputs.attentions,
1149
  )
1150
 
1151
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1152
- def prepare_inputs_for_generation(
1153
- self,
1154
- input_ids,
1155
- past_key_values=None,
1156
- attention_mask=None,
1157
- inputs_embeds=None,
1158
- **kwargs,
1159
- ):
1160
- if past_key_values is not None:
1161
- if isinstance(past_key_values, Cache):
1162
- cache_length = past_key_values.get_seq_length()
1163
- past_length = past_key_values.seen_tokens
1164
- max_cache_length = past_key_values.get_max_length()
1165
- else:
1166
- cache_length = past_length = past_key_values[0][0].shape[2]
1167
- max_cache_length = None
1168
-
1169
- # Keep only the unprocessed tokens:
1170
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1171
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1172
- # input)
1173
- if (
1174
- attention_mask is not None
1175
- and attention_mask.shape[1] > input_ids.shape[1]
1176
- ):
1177
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1178
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1179
- # input_ids based on the past_length.
1180
- elif past_length < input_ids.shape[1]:
1181
- input_ids = input_ids[:, past_length:]
1182
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1183
-
1184
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1185
- if (
1186
- max_cache_length is not None
1187
- and attention_mask is not None
1188
- and cache_length + input_ids.shape[1] > max_cache_length
1189
- ):
1190
- attention_mask = attention_mask[:, -max_cache_length:]
1191
-
1192
- position_ids = kwargs.get("position_ids", None)
1193
- if attention_mask is not None and position_ids is None:
1194
- # create position_ids on the fly for batch generation
1195
- position_ids = attention_mask.long().cumsum(-1) - 1
1196
- position_ids.masked_fill_(attention_mask == 0, 1)
1197
- if past_key_values:
1198
- position_ids = position_ids[:, -input_ids.shape[1] :]
1199
-
1200
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1201
- if inputs_embeds is not None and past_key_values is None:
1202
- model_inputs = {"inputs_embeds": inputs_embeds}
1203
- else:
1204
- model_inputs = {"input_ids": input_ids}
1205
-
1206
- model_inputs.update(
1207
- {
1208
- "position_ids": position_ids,
1209
- "past_key_values": past_key_values,
1210
- "use_cache": kwargs.get("use_cache"),
1211
- "attention_mask": attention_mask,
1212
- }
1213
- )
1214
- return model_inputs
1215
-
1216
  @staticmethod
1217
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1218
  def _reorder_cache(past_key_values, beam_idx):
@@ -1723,6 +1658,38 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
1723
  attention_mask,
1724
  position_ids,
1725
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1726
 
1727
  outputs = self.language_model(
1728
  input_ids=None,
@@ -1759,13 +1726,49 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
1759
  image_features=None,
1760
  **kwargs,
1761
  ):
1762
- res = self.language_model.prepare_inputs_for_generation(
1763
- input_ids, past_key_values, attention_mask, **kwargs
1764
- )
1765
- input_ids = res["input_ids"]
1766
- past_key_values = res["past_key_values"]
1767
- attention_mask = res["attention_mask"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1768
 
 
 
 
 
 
 
 
 
 
1769
  if inputs_embeds is not None and past_key_values is None:
1770
  model_inputs = {"inputs_embeds": inputs_embeds}
1771
  else:
@@ -1773,6 +1776,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
1773
 
1774
  model_inputs.update(
1775
  {
 
1776
  "past_key_values": past_key_values,
1777
  "use_cache": kwargs.get("use_cache"),
1778
  "attention_mask": attention_mask,
 
1148
  attentions=outputs.attentions,
1149
  )
1150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1151
  @staticmethod
1152
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1153
  def _reorder_cache(past_key_values, beam_idx):
 
1658
  attention_mask,
1659
  position_ids,
1660
  )
1661
+ else:
1662
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
1663
+ # generation with cache
1664
+ if past_key_values is not None and image_features is not None and input_ids.shape[1] == 1:
1665
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
1666
+ # that are set to 0
1667
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
1668
+
1669
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
1670
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
1671
+
1672
+ # Get the target length
1673
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
1674
+
1675
+ extended_attention_mask = torch.ones(
1676
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
1677
+ dtype=attention_mask.dtype,
1678
+ device=attention_mask.device,
1679
+ )
1680
+
1681
+ # Filter out only the tokens that can be un-attended, this can happen
1682
+ # if one uses Llava + Fused modules where the cache on the
1683
+ # first iteration is already big enough, or if one passes custom cache
1684
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
1685
+ new_batch_index = batch_index[valid_indices]
1686
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
1687
+
1688
+ # Zero-out the places where we don't need to attend
1689
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
1690
+
1691
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
1692
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1693
 
1694
  outputs = self.language_model(
1695
  input_ids=None,
 
1726
  image_features=None,
1727
  **kwargs,
1728
  ):
1729
+ if past_key_values is not None:
1730
+ if isinstance(past_key_values, Cache):
1731
+ cache_length = past_key_values.get_seq_length()
1732
+ past_length = past_key_values.seen_tokens
1733
+ max_cache_length = past_key_values.get_max_length()
1734
+ else:
1735
+ cache_length = past_length = past_key_values[0][0].shape[2]
1736
+ max_cache_length = None
1737
+
1738
+ # Keep only the unprocessed tokens:
1739
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1740
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1741
+ # input)
1742
+ if (
1743
+ attention_mask is not None
1744
+ and attention_mask.shape[1] > input_ids.shape[1]
1745
+ ):
1746
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1747
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1748
+ # input_ids based on the past_length.
1749
+ elif past_length < input_ids.shape[1]+image_features.shape[1]-1:
1750
+ past_length -= image_features.shape[1]-1
1751
+ input_ids = input_ids[:, past_length:]
1752
+ attention_mask = attention_mask[:, past_length:]
1753
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1754
+
1755
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1756
+ if (
1757
+ max_cache_length is not None
1758
+ and attention_mask is not None
1759
+ and cache_length + input_ids.shape[1] > max_cache_length
1760
+ ):
1761
+ attention_mask = attention_mask[:, -max_cache_length:]
1762
 
1763
+ position_ids = kwargs.get("position_ids", None)
1764
+ if attention_mask is not None and position_ids is None:
1765
+ # create position_ids on the fly for batch generation
1766
+ position_ids = attention_mask.long().cumsum(-1) - 1
1767
+ position_ids.masked_fill_(attention_mask == 0, 1)
1768
+ if past_key_values:
1769
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1770
+
1771
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1772
  if inputs_embeds is not None and past_key_values is None:
1773
  model_inputs = {"inputs_embeds": inputs_embeds}
1774
  else:
 
1776
 
1777
  model_inputs.update(
1778
  {
1779
+ "position_ids": position_ids,
1780
  "past_key_values": past_key_values,
1781
  "use_cache": kwargs.get("use_cache"),
1782
  "attention_mask": attention_mask,