fix-get_max_length-deprecation
References: https://huggingface.co./microsoft/Phi-3-vision-128k-instruct/discussions/69 and https://github.com/huggingface/transformers/issues/36071#issuecomment-2642222305
get_max_length is deprecated as of transformers 4.49; use get_max_cache_shape() instead.
modeling_pharia.py (+3 -3)
@@ -606,7 +606,7 @@ class PhariaModel(PhariaPreTrainedModel):
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -812,9 +812,9 @@ class PhariaForCausalLM(PhariaPreTrainedModel):
         )
         max_cache_length = (
             torch.tensor(
-                past_key_values.get_max_length(), device=input_ids.device
+                past_key_values.get_max_cache_shape(), device=input_ids.device
             )
-            if past_key_values.get_max_length() is not None
+            if past_key_values.get_max_cache_shape() is not None
             else None
         )
         cache_length = (
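For code that has to run on transformers releases both before and after this rename, a small compatibility shim can pick whichever method the cache object provides. The sketch below is not part of this commit; the helper names _max_cache_len and _max_cache_length_tensor are made up for illustration, and the only assumption is that newer Cache objects expose get_max_cache_shape() while older ones still expose the deprecated get_max_length().

import torch


def _max_cache_len(past_key_values):
    """Return the cache's maximum length across transformers versions."""
    # Newer Cache objects expose get_max_cache_shape(); older ones only
    # have the deprecated get_max_length().
    if hasattr(past_key_values, "get_max_cache_shape"):
        return past_key_values.get_max_cache_shape()
    return past_key_values.get_max_length()


def _max_cache_length_tensor(past_key_values, input_ids):
    """Mirror of the second hunk above, written against the shim."""
    max_len = _max_cache_len(past_key_values)
    return (
        torch.tensor(max_len, device=input_ids.device)
        if max_len is not None
        else None
    )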