Fix the issue on latest transformers update

#27
Files changed (1) hide show
  1. modeling_phi.py +1 -1
modeling_phi.py CHANGED
@@ -1161,7 +1161,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
1161
  position_ids = position_ids[:, -input_ids.shape[1] :]
1162
 
1163
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1164
- if inputs_embeds is not None and past_key_values is None:
1165
  model_inputs = {"inputs_embeds": inputs_embeds}
1166
  else:
1167
  model_inputs = {"input_ids": input_ids}
 
1161
  position_ids = position_ids[:, -input_ids.shape[1] :]
1162
 
1163
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1164
+ if inputs_embeds is not None and (input_ids is None or input_ids.shape[1] == 0):
1165
  model_inputs = {"inputs_embeds": inputs_embeds}
1166
  else:
1167
  model_inputs = {"input_ids": input_ids}