Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model (#19)
Browse files- Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model (775f6527d3cfd402c46b03c5fbf355b4f262b705)
Co-authored-by: Tomer Ronen <[email protected]>
- modeling_decilm.py +45 -1
modeling_decilm.py
CHANGED
@@ -25,7 +25,7 @@ import torch.utils.checkpoint
|
|
25 |
from torch import nn
|
26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
from transformers import GenerationConfig
|
28 |
-
from transformers.generation.utils import GenerationMixin,
|
29 |
from transformers.modeling_utils import PreTrainedModel
|
30 |
from transformers.utils import (
|
31 |
add_start_docstrings,
|
@@ -1311,6 +1311,50 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
|
|
1311 |
)
|
1312 |
return model_inputs
|
1313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1314 |
|
1315 |
@add_start_docstrings(
|
1316 |
"""
|
|
|
25 |
from torch import nn
|
26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
from transformers import GenerationConfig
|
28 |
+
from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
|
29 |
from transformers.modeling_utils import PreTrainedModel
|
30 |
from transformers.utils import (
|
31 |
add_start_docstrings,
|
|
|
1311 |
)
|
1312 |
return model_inputs
|
1313 |
|
1314 |
+
def _maybe_initialize_input_ids_for_generation(
|
1315 |
+
self,
|
1316 |
+
inputs: Optional[torch.Tensor] = None,
|
1317 |
+
bos_token_id: Optional[torch.Tensor] = None,
|
1318 |
+
model_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
1319 |
+
) -> torch.LongTensor:
|
1320 |
+
"""
|
1321 |
+
Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
|
1322 |
+
"""
|
1323 |
+
input_ids = super()._maybe_initialize_input_ids_for_generation(
|
1324 |
+
inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs)
|
1325 |
+
if (
|
1326 |
+
"inputs_embeds" in model_kwargs
|
1327 |
+
and input_ids is not None
|
1328 |
+
and input_ids.shape[1] == 0
|
1329 |
+
):
|
1330 |
+
batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
|
1331 |
+
input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
|
1332 |
+
return input_ids
|
1333 |
+
|
1334 |
+
def generate(
|
1335 |
+
self,
|
1336 |
+
inputs: Optional[torch.Tensor] = None,
|
1337 |
+
*args,
|
1338 |
+
**kwargs,
|
1339 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
1340 |
+
"""
|
1341 |
+
Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
|
1342 |
+
"""
|
1343 |
+
only_passed_inputs_embeds = (
|
1344 |
+
"inputs_embeds" in kwargs and
|
1345 |
+
"input_ids" not in kwargs and
|
1346 |
+
inputs is None
|
1347 |
+
)
|
1348 |
+
if only_passed_inputs_embeds:
|
1349 |
+
input_sequence_length = kwargs["inputs_embeds"].shape[1]
|
1350 |
+
|
1351 |
+
generation_output = super().generate(inputs=inputs, *args, **kwargs)
|
1352 |
+
|
1353 |
+
if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
|
1354 |
+
generation_output = generation_output[:, input_sequence_length:]
|
1355 |
+
|
1356 |
+
return generation_output
|
1357 |
+
|
1358 |
|
1359 |
@add_start_docstrings(
|
1360 |
"""
|