Fix generate
- modeling_chatglm.py +14 -8
- tokenization_chatglm.py +0 -2
modeling_chatglm.py

@@ -1054,13 +1054,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # update attention mask
         if "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
-            attention_mask = torch.cat(
-                [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
-            new_attention_mask = attention_mask[:, :, -1:].clone()
-            new_attention_mask[..., -1] = False
-            model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, new_attention_mask], dim=2
-            )
+            if attention_mask is not None and attention_mask.dtype == torch.bool:
+                attention_mask = torch.cat(
+                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+                new_attention_mask = attention_mask[:, :, -1:].clone()
+                new_attention_mask[..., -1] = False
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, new_attention_mask], dim=2
+                )
 
         # update position ids
         if "position_ids" in model_kwargs:
@@ -1092,8 +1093,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
             last_token = input_ids[:, -1].unsqueeze(-1)
-            if attention_mask is not None:
+            if attention_mask is not None and attention_mask.dtype == torch.bool:
                 attention_mask = attention_mask[:, :, -1:]
+            else:
+                attention_mask = None
             if position_ids is not None:
                 position_ids = position_ids[..., -1:]
             else:
@@ -1115,6 +1118,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 "attention_mask": attention_mask
             }
         else:
+            if attention_mask is not None and attention_mask.dtype != torch.bool:
+                logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
+                attention_mask = None
             if attention_mask is None:
                 attention_mask = self.get_masks(
                     input_ids,
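To trace what the first two hunks compute: during incremental decoding, the two `torch.cat` calls grow the cached 4-D bool mask by one key column and one query row per generated token, and `prepare_inputs_for_generation` then keeps only the last query row. The sketch below is not part of the commit; it replays the same operations on a toy mask. The `(batch, 1, seq_len, seq_len)` shape and the True-means-masked convention follow ChatGLM's `get_masks`; the sizes are made up for illustration.

```python
import torch

# Toy stand-in for the cached mask: nothing masked yet.
batch, seq_len = 1, 4
attention_mask = torch.zeros(batch, 1, seq_len, seq_len, dtype=torch.bool)

if attention_mask is not None and attention_mask.dtype == torch.bool:
    # Hunk 1: widen every existing row by one True column, so previously
    # processed queries do not attend to the newly generated key position.
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
    # Append one row for the new token: copy the last row, then clear the
    # final column so the new token may attend to itself.
    new_attention_mask = attention_mask[:, :, -1:].clone()
    new_attention_mask[..., -1] = False
    attention_mask = torch.cat([attention_mask, new_attention_mask], dim=2)

print(attention_mask.shape)  # torch.Size([1, 1, 5, 5])

# Hunk 2: with a past cache, only the last query row is fed to the model.
step_mask = attention_mask[:, :, -1:]
print(step_mask.shape)  # torch.Size([1, 1, 1, 5])
```

The dtype guards are the actual fix: a non-bool mask (e.g. the 0/1 `LongTensor` that a generic tokenizer pads out) would previously be concatenated and sliced as if it were the 4-D bool mask, corrupting generation; now it is dropped and rebuilt instead.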
tokenization_chatglm.py

@@ -382,8 +382,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         mask_token_id = self.sp_tokenizer[self.mask_token]
         gmask_token_id = self.sp_tokenizer[self.gmask_token]
         assert self.padding_side == "left"
-        if return_attention_mask is None:
-            return_attention_mask = "attention_mask" in self.model_input_names
 
         required_input = encoded_inputs[self.model_input_names[0]]
         seq_length = len(required_input)
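The tokenizer change removes the fallback that switched `return_attention_mask` on whenever `attention_mask` appears in `model_input_names`, so `_pad` no longer injects a mask by default. Combined with the third modeling hunk, a missing (or non-bool, warned-about) mask is now rebuilt inside the model via `get_masks`. A minimal end-to-end check, as a sketch only: the model id and `trust_remote_code` flag are the standard ChatGLM-6B loading pattern, not taken from this commit.

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# The tokenizer may now omit attention_mask entirely; generate() still works
# because prepare_inputs_for_generation falls back to get_masks, and any
# non-bool mask is discarded with a warning instead of corrupting decoding.
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=32)
print(tokenizer.decode(outputs[0]))
```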