Post-inference normalization with user-provided locale hints

#6
by sprouts - opened
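
This patch routes every chat entry point through a locale-aware `chat_internal()`: the prompt is built from per-locale question/answer keywords in `QUERY_KEYWORDS`, and the decoded reply is passed through a new `post_process()` that normalizes punctuation for the requested locale. `chat()` keeps its previous behaviour by delegating to `chat_chinese_simplified()`, so existing callers see no change.

A minimal usage sketch (the `THUDM/chatglm-6b` checkpoint name and loading boilerplate follow the upstream README; only `chat()`, `chat_chinese_simplified()`, and `chat_english()` come from this patch):

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

# Unchanged default: chat() delegates to chat_chinese_simplified().
response, history = model.chat(tokenizer, "你好", history=[])

# New English entry point: the prompt uses "Q:"/"A:" keywords and fullwidth
# punctuation in the reply is mapped back to ASCII.
response, history = model.chat_english(tokenizer, "Hello", history=[])
```

Supporting another locale would only need a new `QUERY_KEYWORDS` entry and, if wanted, a matching branch in `post_process()`.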
Files changed (1)
  1. modeling_chatglm.py +61 -6
modeling_chatglm.py CHANGED
@@ -46,6 +46,17 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
 ]
 
+QUERY_KEYWORDS = {
+    'chinese-simplified': {
+        'question': '问:',
+        'answer': '答:',
+    },
+    'english': {
+        'question': 'Q:',
+        'answer': 'A:',
+    }
+}
+
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -1087,9 +1098,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def chat(self, *args, **kwargs):
+        return self.chat_chinese_simplified(*args, **kwargs)
+
+    def chat_chinese_simplified(self, *args, **kwargs):
+        return self.chat_internal(*args, **kwargs, locale='chinese-simplified')
+
+    def chat_english(self, *args, **kwargs):
+        return self.chat_internal(*args, **kwargs, locale='english')
+
     @torch.no_grad()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
-             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+    def chat_internal(self, tokenizer, query: str, locale: str,
+                      history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
+                      do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
         if history is None:
             history = []
         if logits_processor is None:
@@ -1097,20 +1118,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        format_query_keyword_question = QUERY_KEYWORDS[locale]['question']
+        format_query_keyword_answer = QUERY_KEYWORDS[locale]['answer']
         if not history:
             prompt = query
         else:
             prompt = ""
             for i, (old_query, response) in enumerate(history):
-                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
-            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+                prompt += f"[Round {i}]\n{format_query_keyword_question}{old_query}\n{format_query_keyword_answer}{response}\n"
+            prompt += f"[Round {len(history)}]\n{format_query_keyword_question}{query}\n{format_query_keyword_answer}"
         input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
         input_ids = input_ids.to(self.device)
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.post_process(response, locale=locale)
         history = history + [(query, response)]
         return response, history
 
@@ -1165,6 +1187,39 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
 
+    def post_process(self, response: str, locale: str) -> str:
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+
+        if locale == 'chinese-simplified':
+            import re
+            # CJK Unified Ideographs + CJK Unified Ideographs Extension A
+            cjk_regex = r'([\u4e00-\u9fff]|[\u3400-\u4dbf])'
+            regex_mapping = {
+                cjk_regex + ',': r'\1,',
+                cjk_regex + r'\.': r'\1。',
+                cjk_regex + r'\?': r'\1?',
+                cjk_regex + '!': r'\1!',
+                cjk_regex + ':': r'\1:',
+                cjk_regex + ';': r'\1;',
+            }
+            for pattern in regex_mapping:
+                response = re.sub(pattern, regex_mapping[pattern], response)
+            # Nested parentheses are not supported.
+            response = re.sub(r'\(([^\(\)]*(?:[\u4e00-\u9fff]|[\u3400-\u4dbf])[^\(\)]*)\)', r'(\1)', response)
+        elif locale == 'english':
+            mapping = {
+                ',': ',',
+                '。': '.',
+                '?': '?',
+                '!': '!',
+                ':': ':',
+                ';': ';',
+            }
+            for char in mapping:
+                response = response.replace(char, mapping[char])
+        return response
+
     def quantize(self, bits: int):
         from .quantization import quantize
         self.transformer = quantize(self.transformer, bits)
 
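
For reference, a quick sketch of what `post_process()` does to each locale's output (illustrative strings only, called on a patched `model` instance):

```python
# chinese-simplified: ASCII punctuation directly following a CJK character
# becomes fullwidth, and ASCII parentheses around CJK text become fullwidth.
print(model.post_process("你好,世界. 很好!(测试)", locale='chinese-simplified'))
# -> 你好,世界。 很好!(测试)

# english: fullwidth punctuation is mapped back to its ASCII equivalent.
print(model.post_process("Hello, world。 Fine!", locale='english'))
# -> Hello, world. Fine!
```

Note that the Chinese branch only upgrades punctuation that directly follows a CJK character, so mixed-language sentences keep their ASCII punctuation after Latin text.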