zxdu20 committed
Commit fb23542
Parent(s): 08bc851

Fix generate

Files changed (2):
  1. modeling_chatglm.py +14 -8
  2. tokenization_chatglm.py +0 -2
modeling_chatglm.py CHANGED
@@ -1054,13 +1054,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # update attention mask
         if "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
-            attention_mask = torch.cat(
-                [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
-            new_attention_mask = attention_mask[:, :, -1:].clone()
-            new_attention_mask[..., -1] = False
-            model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, new_attention_mask], dim=2
-            )
+            if attention_mask is not None and attention_mask.dtype == torch.bool:
+                attention_mask = torch.cat(
+                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+                new_attention_mask = attention_mask[:, :, -1:].clone()
+                new_attention_mask[..., -1] = False
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, new_attention_mask], dim=2
+                )
 
         # update position ids
         if "position_ids" in model_kwargs:
@@ -1092,8 +1093,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
             last_token = input_ids[:, -1].unsqueeze(-1)
-            if attention_mask is not None:
+            if attention_mask is not None and attention_mask.dtype == torch.bool:
                 attention_mask = attention_mask[:, :, -1:]
+            else:
+                attention_mask = None
             if position_ids is not None:
                 position_ids = position_ids[..., -1:]
             else:
@@ -1115,6 +1118,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 "attention_mask": attention_mask
             }
         else:
+            if attention_mask is not None and attention_mask.dtype != torch.bool:
+                logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
+                attention_mask = None
             if attention_mask is None:
                 attention_mask = self.get_masks(
                     input_ids,
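
A minimal sketch (not part of the commit) of the per-step mask update that the first hunk now guards. The helper name update_attention_mask and the toy shapes are illustrative only. In ChatGLM's bool mask convention, True means "masked out" (get_masks returns (mask < 0.5).bool()); the fix applies the 4D cat-based update only to such bool masks, whereas before the guard a standard 2D padding mask would hit torch.cat(..., dim=3) and raise a dimension-out-of-range error inside generate().

import torch

def update_attention_mask(attention_mask):
    # Mirrors the fixed _update_model_kwargs_for_generation step: only a bool
    # (ChatGLM-style 4D [batch, 1, seq, seq]) mask is extended per new token.
    if attention_mask is not None and attention_mask.dtype == torch.bool:
        # Append one key column of True: existing queries may not attend the
        # not-yet-generated token (True = masked in ChatGLM's convention).
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
        # Append one query row for the new token, unmasking its own position.
        new_attention_mask = attention_mask[:, :, -1:].clone()
        new_attention_mask[..., -1] = False
        attention_mask = torch.cat([attention_mask, new_attention_mask], dim=2)
    return attention_mask

bool_mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)  # ChatGLM-style 4D mask
long_mask = torch.ones(1, 4, dtype=torch.long)         # generic 2D padding mask
print(update_attention_mask(bool_mask).shape)  # torch.Size([1, 1, 5, 5])
print(update_attention_mask(long_mask).shape)  # left unchanged: torch.Size([1, 4])
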
tokenization_chatglm.py CHANGED
@@ -382,8 +382,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         mask_token_id = self.sp_tokenizer[self.mask_token]
         gmask_token_id = self.sp_tokenizer[self.gmask_token]
         assert self.padding_side == "left"
-        if return_attention_mask is None:
-            return_attention_mask = "attention_mask" in self.model_input_names
 
         required_input = encoded_inputs[self.model_input_names[0]]
         seq_length = len(required_input)
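
For context, a hypothetical smoke test of the fixed generate path. The model id THUDM/chatglm-6b is an assumption (the diff does not name the repository), and the expected behavior follows from the hunks above: a bool mask takes the incremental per-step update, while a non-bool mask now triggers the one-time warning and is rebuilt via get_masks instead of crashing.

from transformers import AutoModel, AutoTokenizer

# Assumed repo id; not stated anywhere in this commit.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# ChatGLMTokenizer._pad builds the model's own bool attention mask, so
# generate() should take the incremental per-step mask update path.
inputs = tokenizer(["Hello"], return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**inputs, max_length=64)
print(tokenizer.decode(outputs[0]))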