Update modeling_chatglm.py for inputs_embeds

#45
Opened by Xipotzzz
Files changed (1)
  1. modeling_chatglm.py +22 -11
modeling_chatglm.py CHANGED
@@ -914,11 +914,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
914
  use_cache = False
915
 
916
  if input_ids is not None and inputs_embeds is not None:
917
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
918
- elif input_ids is not None:
919
  batch_size, seq_length = input_ids.shape[:2]
920
  elif inputs_embeds is not None:
921
- batch_size, seq_length, _ = inputs_embeds.shape[:2]
922
  else:
923
  raise ValueError("You have to specify either input_ids or inputs_embeds")
924
 
@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
972
 
973
  if attention_mask is None:
974
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
975
-
976
  else:
977
- attention_mask = attention_mask.to(input_ids.device)
978
 
979
  for i, layer in enumerate(self.layers):
980
 
@@ -1105,6 +1104,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1105
  def prepare_inputs_for_generation(
1106
  self,
1107
  input_ids: torch.LongTensor,
 
1108
  past: Optional[torch.Tensor] = None,
1109
  past_key_values: Optional[torch.Tensor] = None,
1110
  attention_mask: Optional[torch.Tensor] = None,
@@ -1165,12 +1165,23 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1165
  use_gmasks=use_gmasks
1166
  )
1167
 
1168
- return {
1169
- "input_ids": input_ids,
1170
- "past_key_values": past,
1171
- "position_ids": position_ids,
1172
- "attention_mask": attention_mask
1173
- }
 
 
 
 
 
 
 
 
 
 
 
1174
 
1175
  def forward(
1176
  self,
 
914
  use_cache = False
915
 
916
  if input_ids is not None and inputs_embeds is not None:
917
+ logger.warning("You passed both `inputs_embeds` and `input_ids`. Will use `inputs_embeds`")
918
+ if input_ids is not None:
919
  batch_size, seq_length = input_ids.shape[:2]
920
  elif inputs_embeds is not None:
921
+ batch_size, seq_length = inputs_embeds.shape[:2]
922
  else:
923
  raise ValueError("You have to specify either input_ids or inputs_embeds")
924
 
 
972
 
973
  if attention_mask is None:
974
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
 
975
  else:
976
+ attention_mask = attention_mask.to(hidden_states.device)
977
 
978
  for i, layer in enumerate(self.layers):
979
 
 
1104
  def prepare_inputs_for_generation(
1105
  self,
1106
  input_ids: torch.LongTensor,
1107
+ inputs_embeds: Optional[torch.Tensor] = None,
1108
  past: Optional[torch.Tensor] = None,
1109
  past_key_values: Optional[torch.Tensor] = None,
1110
  attention_mask: Optional[torch.Tensor] = None,
 
1165
  use_gmasks=use_gmasks
1166
  )
1167
 
1168
+ if inputs_embeds is not None:
1169
+ assert input_ids.size(1) == inputs_embeds.size(
1170
+ 1
1171
+ ), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
1172
+ return {
1173
+ "inputs_embeds": inputs_embeds,
1174
+ "past_key_values": past,
1175
+ "position_ids": position_ids,
1176
+ "attention_mask": attention_mask,
1177
+ }
1178
+ else:
1179
+ return {
1180
+ "input_ids": input_ids,
1181
+ "past_key_values": past,
1182
+ "position_ids": position_ids,
1183
+ "attention_mask": attention_mask,
1184
+ }
1185
 
1186
  def forward(
1187
  self,