Fix input embeds
Browse files- modeling_chatglm.py +2 -3
modeling_chatglm.py
CHANGED
@@ -918,7 +918,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
918 |
elif input_ids is not None:
|
919 |
batch_size, seq_length = input_ids.shape[:2]
|
920 |
elif inputs_embeds is not None:
|
921 |
-
batch_size, seq_length
|
922 |
else:
|
923 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
924 |
|
@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
972 |
|
973 |
if attention_mask is None:
|
974 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
975 |
-
|
976 |
else:
|
977 |
-
attention_mask = attention_mask.to(
|
978 |
|
979 |
for i, layer in enumerate(self.layers):
|
980 |
|
|
|
918 |
elif input_ids is not None:
|
919 |
batch_size, seq_length = input_ids.shape[:2]
|
920 |
elif inputs_embeds is not None:
|
921 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
922 |
else:
|
923 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
924 |
|
|
|
972 |
|
973 |
if attention_mask is None:
|
974 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
|
|
975 |
else:
|
976 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
977 |
|
978 |
for i, layer in enumerate(self.layers):
|
979 |
|