AzizBelaweid committed
Commit 2cfde52 · verified · 1 Parent(s): d8fc101

fix-get_max_length-deprecation


https://huggingface.co./microsoft/Phi-3-vision-128k-instruct/discussions/69, https://github.com/huggingface/transformers/issues/36071#issuecomment-2642222305

get_max_length is deprecated as of transformers 4.49; replace it with get_max_cache_shape.
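A minimal sketch, not part of this commit, of the same replacement for code that must run on both older and newer transformers releases (the helper name max_cache_len is assumed for illustration): prefer get_max_cache_shape when the cache object provides it and fall back to the deprecated get_max_length otherwise.

def max_cache_len(past_key_values):
    # Newer transformers Cache objects expose get_max_cache_shape();
    # older releases only provide the deprecated get_max_length().
    if hasattr(past_key_values, "get_max_cache_shape"):
        return past_key_values.get_max_cache_shape()
    return past_key_values.get_max_length()

Both methods return None for dynamic caches that have no fixed maximum length, which is why the diff below keeps the "is not None" check around the result.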

Files changed (1)
  1. modeling_pharia.py +3 -3
modeling_pharia.py CHANGED
@@ -606,7 +606,7 @@ class PhariaModel(PhariaPreTrainedModel):
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -812,9 +812,9 @@ class PhariaForCausalLM(PhariaPreTrainedModel):
         )
         max_cache_length = (
             torch.tensor(
-                past_key_values.get_max_length(), device=input_ids.device
+                past_key_values.get_max_cache_shape(), device=input_ids.device
             )
-            if past_key_values.get_max_length() is not None
+            if past_key_values.get_max_cache_shape() is not None
             else None
         )
         cache_length = (