from collections.abc import Callable
from typing import List, Tuple, Union

import os
import pickle
import re
import unicodedata

from datasets import Dataset
from transformers.pipelines.pt_utils import KeyDataset


class Translator:
    """Wraps a translation pipeline: splits long Chinese text into chunks,
    translates the chunks, and re-assembles the result into the original layout."""

    def __init__(
        self,
        pipe: Callable,
        max_length: int = 500,
        batch_size: int = 16,
        save_every_step=100,
        text_key="text",
        save_filename=None,
    ):
        self.pipe = pipe
        # prefer the model's own max_length when the pipeline exposes one
        self.max_length = (
            pipe.model.config.max_length
            if hasattr(pipe, "model") and hasattr(pipe.model.config, "max_length")
            else max_length
        )
        self.batch_size = batch_size
        self.save_every_step = save_every_step
        self.save_filename = save_filename
        self.text_key = text_key

    def _is_chinese(self, text: str) -> bool:
        # True if the text contains any CJK ideograph (including extension blocks,
        # compatibility ideographs, and an optional trailing variation selector)
        return (
            re.search(
                "[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002ebef\U00030000-\U000323af\ufa0e\ufa0f\ufa11\ufa13\ufa14\ufa1f\ufa21\ufa23\ufa24\ufa27\ufa28\ufa29\u3006\u3007][\ufe00-\ufe0f\U000e0100-\U000e01ef]?",
                text,
            )
            is not None
        )

    def _split_sentences(self, text: str) -> List[str]:
        # Split text longer than max_length into chunks, cutting at the sentence
        # delimiter closest to (but not beyond) max_length.
        if len(text) <= self.max_length:
            return [text]
        delimiters = set("。!?;…!?")
        sent_list = []
        sent = text
        while len(sent) > self.max_length:
            # search backwards from max_length for the nearest delimiter
            for i in range(self.max_length, 0, -1):
                if sent[i] in delimiters:
                    sent_list.append(sent[: i + 1])
                    sent = sent[i + 1 :]
                    break
            else:
                # no delimiter found: hard-split at max_length so we never loop forever
                sent_list.append(sent[: self.max_length])
                sent = sent[self.max_length :]
        if len(sent) > 0:
            sent_list.append(sent)
        return sent_list

    def _preprocess(self, text: str) -> Tuple[List[str], str]:
        # Collect the Chinese lines to translate and build a format template
        # so the translated chunks can be substituted back into the original text.
        lines = text.split("\n")
        sentences = []
        template = text.replace("{", "{{").replace("}", "}}")
        chunk_index = 0
        for line in lines:
            sentence = line.strip()
            if len(sentence) > 0 and self._is_chinese(sentence):
                chunks = self._split_sentences(sentence)
                for chunk in chunks:
                    sentences.append(chunk)
                    chunk = chunk.replace("{", "{{").replace("}", "}}")
                    template = template.replace(chunk, "{%d}" % chunk_index, 1)
                    chunk_index += 1
        return sentences, template

    def _postprocess(
        self, template: str, src_sentences: List[str], translations: List[str]
    ) -> str:
        # Restore alphanumeric/punctuation runs from the source sentence when the
        # translation only changed their case or full-width/half-width form.
        processed = []
        alphanumeric_regex = re.compile(
            r"""([a-zA-Z0-9\d+'",,(())::;;“”。.??!!‘’]+)"""
        )

        def hash_text(text: List[str]) -> str:
            text = "|".join(text)
            puncts_map = str.maketrans(",;:()。?!“”‘’", ",;:().?!\"\"''")
            text = text.translate(puncts_map)
            return unicodedata.normalize("NFKC", text).lower()

        for i, p in enumerate(translations):
            src_sentence = src_sentences[i]
            # p = re.sub(',', ',', p)  # replace all commas
            # p = re.sub(';', ';', p)  # replace semi-colon
            # p = re.sub(':', ':', p)  # replace colon
            # p = re.sub('\(', '(', p)  # replace round bracket
            # p = re.sub('\)', ')', p)  # replace round bracket
            # p = re.sub(r'([\d]),([\d])', r'\1,\2', p)
            src_matches = re.findall(alphanumeric_regex, src_sentence)
            translated_matches = re.findall(alphanumeric_regex, p)
            # counts do not match, or there is nothing to restore
            if (
                len(src_matches) != len(translated_matches)
                or len(src_matches) == 0
                or len(translated_matches) == 0
            ):
                processed.append(p)
                continue
            # normalize full-width to half-width and lower-case before comparing
            src_hashes = hash_text(src_matches)
            translated_hashes = hash_text(translated_matches)
            if src_hashes != translated_hashes:
                processed.append(p)
                continue
            # replace each translated run with the corresponding source run
            for j in range(len(src_matches)):
                p = p.replace(translated_matches[j], src_matches[j], 1)
            processed.append(p)
        output = template.format(*processed)
        return output

    def _save(self, translations):
        with open("{}.pkl".format(self.save_filename), "wb") as f:
            pickle.dump(translations, f)

    def __call__(self, inputs: Union[str, List[str], Dataset]) -> List[str]:
        templates = []
        sentences = []
        sentence_indices = []
        outputs = []
        # normalize the input into a Dataset
        if isinstance(inputs, Dataset):
            ds = inputs
        else:
            if isinstance(inputs, str):
                inputs = [inputs]
            ds = Dataset.from_list([{"text": text} for text in inputs])

        # break each document into translatable chunks, remembering which
        # chunk indices belong to which document
        for text_input in ds:
            chunks, template = self._preprocess(text_input["text"])
            templates.append(template)
            sentence_indices.append([])
            for chunk in chunks:
                sentences.append(chunk)
                sentence_indices[-1].append(len(sentences) - 1)

        # resume from a previous checkpoint if one exists
        save_path = "{}.pkl".format(self.save_filename)
        resume_from_file = (
            save_path
            if self.save_filename is not None and os.path.isfile(save_path)
            else None
        )
        if resume_from_file is None:
            translations = []
        else:
            with open(resume_from_file, "rb") as f:
                translations = pickle.load(f)
        print("translations:", len(translations))
        print("dataset:", len(sentences))
        if resume_from_file is not None:
            print(
                "Resuming from {} ({} records)".format(
                    resume_from_file, len(translations)
                )
            )

        # only translate the chunks that have not been translated yet
        ds = Dataset.from_list(
            [{"text": text} for text in sentences[len(translations) :]]
        )
        total_records = len(ds)
        if total_records > 0:
            step = 0
            for out in self.pipe(
                KeyDataset(ds, self.text_key), batch_size=self.batch_size
            ):
                translations.append(out[0])
                # checkpoint the results every save_every_step steps
                if (
                    step != 0
                    and self.save_filename is not None
                    and step % self.save_every_step == 0
                ):
                    self._save(translations)
                step += 1
        if self.save_filename is not None and total_records > 0:
            self._save(translations)

        # re-assemble each document from its translated chunks
        for i, template in enumerate(templates):
            try:
                src_sentences = [sentences[index] for index in sentence_indices[i]]
                translated_sentences = [
                    translations[index]["translation_text"]
                    for index in sentence_indices[i]
                ]
                output = self._postprocess(
                    template, src_sentences, translated_sentences
                )
                outputs.append(output)
            except Exception as error:
                print(error)
                print(template)
                # print(template, sentence_indices[i], len(translations))
        return outputs


def fake_pipe(text: List[str], batch_size: int = 16):
    # A stand-in for a translation pipeline used by the test harness below: it
    # yields exactly one result per input and deliberately alters some tokens so
    # that _postprocess has something to restore.
    for i in range(len(text)):
        if "Acetaminophen" in text[i]:
            # test case error: change the casing so _postprocess must restore it
            yield [
                {"translation_text": text[i].replace("Acetaminophen", "ACEtaminophen")}
            ]
        elif "123" in text[i]:
            # replace ASCII digits with full-width digits
            yield [{"translation_text": text[i].replace("123", "123")}]
        elif "abc" in text[i]:
            yield [{"translation_text": text[i].replace("abc", "ABC")}]
        else:
            yield [{"translation_text": text[i]}]


if __name__ == "__main__":
    translator = Translator(fake_pipe, max_length=60)
    text1 = "对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:"
    text2 = """对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:

```
# 设置用于匹配输入的关键字,并定义相应的回答数据字典。
keywords = {'你好': '你好!很高兴见到你。', '再见': '再见!有机会再聊。', '你叫什么': '我是一个聊天机器人。', '你是谁': '我是一个基于人工智能技术制作的聊天机器人。'}

# 定义用于处理用户输入的函数。
def chatbot(input_text):
    # 遍历关键字数据字典,匹配用户的输入。
    for key in keywords:
        if key in input_text:
            # 如果匹配到了关键字,返回相应的回答。
            return keywords[key]
    # 如果没有找到匹配的关键字,返回默认回答。
    return "对不起,我不知道你在说什么。"

# 运行聊天机器人。
while True:
    # 获取用户输入。
    user_input = input('用户: ')
    # 如果用户输入“再见”,退出程序。
    if user_input == '再见':
        break
    # 处理用户输入,并打印回答。
    print('机器人: ' + chatbot(user_input))
```

这是一个非常简单的例子。对于实用的聊天机器人,可能需要使用更复杂的 NLP 技术和机器学习模型,以更好地理解和回答用户的问题。"""
    text3 = "布洛芬(Ibuprofen)同撲熱息痛(Acetaminophen)係兩種常見嘅非處方藥,用於緩解疼痛、發燒同關節痛。"
    text4 = "123 abc def's"
    outputs = translator([text1, text2, text3])
    # print('Output: ', outputs[0], '\nInput: ', text1)
    text2_lines = text2.split("\n")
    for i, text in enumerate(outputs[1].split("\n")):
        # print the first line that differs from the input
        if text != text2_lines[i]:
            print("Output: ", text, "\nInput: ", text2_lines[i])
            break
    assert outputs[0] == text1
    assert outputs[1] == text2
    assert outputs[2] == text3