zxdu20 commited on
Commit
d8a6cfc
·
1 Parent(s): f6b88da

Update decode method in tokenizer

Browse files
Files changed (1) hide show
  1. tokenization_chatglm.py +20 -7
tokenization_chatglm.py CHANGED
@@ -31,6 +31,9 @@ class TextTokenizer:
31
  def tokenize(self, text):
32
  return self.sp.EncodeAsPieces(text)
33
 
 
 
 
34
  def convert_tokens_to_ids(self, tokens):
35
  return [self.sp.PieceToId(token) for token in tokens]
36
 
@@ -111,16 +114,25 @@ class SPTokenizer:
111
  tokens = [x + self.num_image_tokens for x in tmp]
112
  return tokens if add_dummy_prefix else tokens[2:]
113
 
114
- def decode(self, text_ids: List[int]) -> str:
115
- ids = [int(_id) - self.num_image_tokens for _id in text_ids]
116
- ids = [_id for _id in ids if _id >= 0]
117
- text = self._get_text_tokenizer().decode(ids)
118
  text = text.replace("<n>", "\n")
119
  text = text.replace(SPTokenizer.get_tab_token(), "\t")
120
  for i in range(2, self.max_blank_length + 1):
121
  text = text.replace(self.get_blank_token(i), " " * i)
122
  return text
123
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def tokenize(
125
  self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
126
  ) -> List[str]:
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
256
 
257
  return seq
258
 
 
 
 
259
  def _decode(
260
  self,
261
  token_ids: Union[int, List[int]],
262
- skip_special_tokens: bool = False,
263
- clean_up_tokenization_spaces: bool = True,
264
  **kwargs
265
  ) -> str:
266
  if isinstance(token_ids, int):
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
269
  return ""
270
  if self.pad_token_id in token_ids: # remove pad
271
  token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
272
- return self.sp_tokenizer.decode(token_ids)
273
 
274
  def _convert_token_to_id(self, token):
275
  """ Converts a token (str) in an id using the vocab. """
 
31
  def tokenize(self, text):
32
  return self.sp.EncodeAsPieces(text)
33
 
34
+ def convert_tokens_to_string(self, tokens):
35
+ return self.sp.DecodePieces(tokens)
36
+
37
  def convert_tokens_to_ids(self, tokens):
38
  return [self.sp.PieceToId(token) for token in tokens]
39
 
 
114
  tokens = [x + self.num_image_tokens for x in tmp]
115
  return tokens if add_dummy_prefix else tokens[2:]
116
 
117
+ def postprocess(self, text):
 
 
 
118
  text = text.replace("<n>", "\n")
119
  text = text.replace(SPTokenizer.get_tab_token(), "\t")
120
  for i in range(2, self.max_blank_length + 1):
121
  text = text.replace(self.get_blank_token(i), " " * i)
122
  return text
123
 
124
+ def decode(self, text_ids: List[int]) -> str:
125
+ ids = [int(_id) - self.num_image_tokens for _id in text_ids]
126
+ ids = [_id for _id in ids if _id >= 0]
127
+ text = self._get_text_tokenizer().decode(ids)
128
+ text = self.postprocess(text)
129
+ return text
130
+
131
+ def decode_tokens(self, tokens: List[str]) -> str:
132
+ text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
133
+ text = self.postprocess(text)
134
+ return text
135
+
136
  def tokenize(
137
  self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
138
  ) -> List[str]:
 
268
 
269
  return seq
270
 
271
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
272
+ return self.sp_tokenizer.decode_tokens(tokens)
273
+
274
  def _decode(
275
  self,
276
  token_ids: Union[int, List[int]],
 
 
277
  **kwargs
278
  ) -> str:
279
  if isinstance(token_ids, int):
 
282
  return ""
283
  if self.pad_token_id in token_ids: # remove pad
284
  token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
285
+ return super()._decode(token_ids, **kwargs)
286
 
287
  def _convert_token_to_id(self, token):
288
  """ Converts a token (str) in an id using the vocab. """