zxdu20 commited on
Commit
9324de7
1 Parent(s): d467eff

Use gmask in first place

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +4 -4
modeling_chatglm.py CHANGED
@@ -922,8 +922,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
922
 
923
  if position_ids is None:
924
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
925
- mask_token = MASK if MASK in input_ids else gMASK
926
- use_gmask = False if MASK in input_ids else True
927
 
928
  mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
929
  position_ids = self.get_position_ids(
@@ -1085,8 +1085,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1085
  ) -> dict:
1086
  batch_size, seq_length = input_ids.shape
1087
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
1088
- mask_token = MASK if MASK in input_ids else gMASK
1089
- use_gmask = False if MASK in input_ids else True
1090
  seqs = input_ids.tolist()
1091
  mask_positions = [seq.index(mask_token) for seq in seqs]
1092
 
 
922
 
923
  if position_ids is None:
924
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
925
+ mask_token = gMASK if gMASK in input_ids else MASK
926
+ use_gmask = True if gMASK in input_ids else False
927
 
928
  mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
929
  position_ids = self.get_position_ids(
 
1085
  ) -> dict:
1086
  batch_size, seq_length = input_ids.shape
1087
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
1088
+ mask_token = gMASK if gMASK in input_ids else MASK
1089
+ use_gmask = True if gMASK in input_ids else False
1090
  seqs = input_ids.tolist()
1091
  mask_positions = [seq.index(mask_token) for seq in seqs]
1092