visheratin
committed on
Update modeling file
Browse files- modeling_llava.py +75 -71
modeling_llava.py
CHANGED
@@ -1148,71 +1148,6 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|
1148 |
attentions=outputs.attentions,
|
1149 |
)
|
1150 |
|
1151 |
-
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
1152 |
-
def prepare_inputs_for_generation(
|
1153 |
-
self,
|
1154 |
-
input_ids,
|
1155 |
-
past_key_values=None,
|
1156 |
-
attention_mask=None,
|
1157 |
-
inputs_embeds=None,
|
1158 |
-
**kwargs,
|
1159 |
-
):
|
1160 |
-
if past_key_values is not None:
|
1161 |
-
if isinstance(past_key_values, Cache):
|
1162 |
-
cache_length = past_key_values.get_seq_length()
|
1163 |
-
past_length = past_key_values.seen_tokens
|
1164 |
-
max_cache_length = past_key_values.get_max_length()
|
1165 |
-
else:
|
1166 |
-
cache_length = past_length = past_key_values[0][0].shape[2]
|
1167 |
-
max_cache_length = None
|
1168 |
-
|
1169 |
-
# Keep only the unprocessed tokens:
|
1170 |
-
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
1171 |
-
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
1172 |
-
# input)
|
1173 |
-
if (
|
1174 |
-
attention_mask is not None
|
1175 |
-
and attention_mask.shape[1] > input_ids.shape[1]
|
1176 |
-
):
|
1177 |
-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1178 |
-
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1179 |
-
# input_ids based on the past_length.
|
1180 |
-
elif past_length < input_ids.shape[1]:
|
1181 |
-
input_ids = input_ids[:, past_length:]
|
1182 |
-
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
1183 |
-
|
1184 |
-
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
1185 |
-
if (
|
1186 |
-
max_cache_length is not None
|
1187 |
-
and attention_mask is not None
|
1188 |
-
and cache_length + input_ids.shape[1] > max_cache_length
|
1189 |
-
):
|
1190 |
-
attention_mask = attention_mask[:, -max_cache_length:]
|
1191 |
-
|
1192 |
-
position_ids = kwargs.get("position_ids", None)
|
1193 |
-
if attention_mask is not None and position_ids is None:
|
1194 |
-
# create position_ids on the fly for batch generation
|
1195 |
-
position_ids = attention_mask.long().cumsum(-1) - 1
|
1196 |
-
position_ids.masked_fill_(attention_mask == 0, 1)
|
1197 |
-
if past_key_values:
|
1198 |
-
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1199 |
-
|
1200 |
-
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1201 |
-
if inputs_embeds is not None and past_key_values is None:
|
1202 |
-
model_inputs = {"inputs_embeds": inputs_embeds}
|
1203 |
-
else:
|
1204 |
-
model_inputs = {"input_ids": input_ids}
|
1205 |
-
|
1206 |
-
model_inputs.update(
|
1207 |
-
{
|
1208 |
-
"position_ids": position_ids,
|
1209 |
-
"past_key_values": past_key_values,
|
1210 |
-
"use_cache": kwargs.get("use_cache"),
|
1211 |
-
"attention_mask": attention_mask,
|
1212 |
-
}
|
1213 |
-
)
|
1214 |
-
return model_inputs
|
1215 |
-
|
1216 |
@staticmethod
|
1217 |
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
1218 |
def _reorder_cache(past_key_values, beam_idx):
|
@@ -1723,6 +1658,38 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
|
|
1723 |
attention_mask,
|
1724 |
position_ids,
|
1725 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1726 |
|
1727 |
outputs = self.language_model(
|
1728 |
input_ids=None,
|
@@ -1759,13 +1726,49 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
|
|
1759 |
image_features=None,
|
1760 |
**kwargs,
|
1761 |
):
|
1762 |
-
|
1763 |
-
|
1764 |
-
|
1765 |
-
|
1766 |
-
|
1767 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1768 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1769 |
if inputs_embeds is not None and past_key_values is None:
|
1770 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
1771 |
else:
|
@@ -1773,6 +1776,7 @@ class LlavaForCausalLM(LlavaPreTrainedModel):
|
|
1773 |
|
1774 |
model_inputs.update(
|
1775 |
{
|
|
|
1776 |
"past_key_values": past_key_values,
|
1777 |
"use_cache": kwargs.get("use_cache"),
|
1778 |
"attention_mask": attention_mask,
|
|
|
1148 |
attentions=outputs.attentions,
|
1149 |
)
|
1150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1151 |
@staticmethod
|
1152 |
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
1153 |
def _reorder_cache(past_key_values, beam_idx):
|
|
|
1658 |
attention_mask,
|
1659 |
position_ids,
|
1660 |
)
|
1661 |
+
else:
|
1662 |
+
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
1663 |
+
# generation with cache
|
1664 |
+
if past_key_values is not None and image_features is not None and input_ids.shape[1] == 1:
|
1665 |
+
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
1666 |
+
# that are set to 0
|
1667 |
+
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
1668 |
+
|
1669 |
+
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
1670 |
+
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
1671 |
+
|
1672 |
+
# Get the target length
|
1673 |
+
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
1674 |
+
|
1675 |
+
extended_attention_mask = torch.ones(
|
1676 |
+
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
1677 |
+
dtype=attention_mask.dtype,
|
1678 |
+
device=attention_mask.device,
|
1679 |
+
)
|
1680 |
+
|
1681 |
+
# Filter out only the tokens that can be un-attended, this can happen
|
1682 |
+
# if one uses Llava + Fused modules where the cache on the
|
1683 |
+
# first iteration is already big enough, or if one passes custom cache
|
1684 |
+
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
1685 |
+
new_batch_index = batch_index[valid_indices]
|
1686 |
+
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
1687 |
+
|
1688 |
+
# Zero-out the places where we don't need to attend
|
1689 |
+
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
1690 |
+
|
1691 |
+
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
1692 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
1693 |
|
1694 |
outputs = self.language_model(
|
1695 |
input_ids=None,
|
|
|
1726 |
image_features=None,
|
1727 |
**kwargs,
|
1728 |
):
|
1729 |
+
if past_key_values is not None:
|
1730 |
+
if isinstance(past_key_values, Cache):
|
1731 |
+
cache_length = past_key_values.get_seq_length()
|
1732 |
+
past_length = past_key_values.seen_tokens
|
1733 |
+
max_cache_length = past_key_values.get_max_length()
|
1734 |
+
else:
|
1735 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
1736 |
+
max_cache_length = None
|
1737 |
+
|
1738 |
+
# Keep only the unprocessed tokens:
|
1739 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
1740 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
1741 |
+
# input)
|
1742 |
+
if (
|
1743 |
+
attention_mask is not None
|
1744 |
+
and attention_mask.shape[1] > input_ids.shape[1]
|
1745 |
+
):
|
1746 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1747 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1748 |
+
# input_ids based on the past_length.
|
1749 |
+
elif past_length < input_ids.shape[1]+image_features.shape[1]-1:
|
1750 |
+
past_length -= image_features.shape[1]-1
|
1751 |
+
input_ids = input_ids[:, past_length:]
|
1752 |
+
attention_mask = attention_mask[:, past_length:]
|
1753 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
1754 |
+
|
1755 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
1756 |
+
if (
|
1757 |
+
max_cache_length is not None
|
1758 |
+
and attention_mask is not None
|
1759 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
1760 |
+
):
|
1761 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
1762 |
|
1763 |
+
position_ids = kwargs.get("position_ids", None)
|
1764 |
+
if attention_mask is not None and position_ids is None:
|
1765 |
+
# create position_ids on the fly for batch generation
|
1766 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1767 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1768 |
+
if past_key_values:
|
1769 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1770 |
+
|
1771 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1772 |
if inputs_embeds is not None and past_key_values is None:
|
1773 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
1774 |
else:
|
|
|
1776 |
|
1777 |
model_inputs.update(
|
1778 |
{
|
1779 |
+
"position_ids": position_ids,
|
1780 |
"past_key_values": past_key_values,
|
1781 |
"use_cache": kwargs.get("use_cache"),
|
1782 |
"attention_mask": attention_mask,
|