zxdu20 commited on
Commit
34a7f82
2 Parent(s): 3118bcb 35ca523

Merge branch 'main' into pr/45

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +2 -0
modeling_chatglm.py CHANGED
@@ -970,6 +970,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
970
 
971
  if attention_mask is None:
972
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
 
 
973
 
974
  for i, layer in enumerate(self.layers):
975
 
 
970
 
971
  if attention_mask is None:
972
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
973
+ else:
974
+ attention_mask = attention_mask.to(hidden_states.device)
975
 
976
  for i, layer in enumerate(self.layers):
977