indiejoseph commited on
Commit
cd13b6c
·
1 Parent(s): ab0f5b9

Upload 2 files

Browse files
Files changed (2) hide show
  1. translation_pipeline.py +159 -0
  2. translator.py +437 -0
translation_pipeline.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TranslationPipeline as HfTranslationPipeline
2
+ from transformers.pipelines.text2text_generation import ReturnType
3
+ import re
4
+
5
+
6
+ hans_chars = set(
7
+ "万与专业丛东丝丢两严丧个丬丰临为丽举么义乌乐乔习乡书买乱争于亏亘亚产亩亲亵亸亿仅从仑仓仪们价众优会伛伞伟传伤伥伦伧伪伫体佥侠侣侥侦侧侨侩侪侬俣俦俨俩俪俭债倾偬偻偾偿傥傧储傩儿兑兖党兰关兴兹养兽冁内冈册写军农冢冯冲决况冻净凄凉减凑凛凤凫凭凯击凼凿刍刘则刚创删别刬刭刽刿剀剂剐剑剥剧劝办务劢动励劲劳势勋勐勚匀匦匮区医华协单卖卢卤卫却卺厂厅历厉压厌厍厕厢厣厦厨厩厮县参叆叇双发变叙叠叶号叹叽吁吕吗吣吨听启吴呒呓呕呖呗员呙呛呜咏咙咛咝咤咴哌哑哒哓哔哕哗哙哜哝哟唛唝唠唡唢唣唤唿啧啬啭啮啰啴啸喷喽喾嗫嗳嘘嘤嘱噜嚣嚯团园囱围囵国图圆圣圹场坂坏块坚坛坜坝坞坟坠垄垅垆垒垦垧垩垫垭垯垱垲垴埘埙埚埝埯堑堕塆墙壮声壳壶壸处备复够头夸夹夺奁奂奋奖奥妆妇妈妩妪妫姗娄娅娆娇娈娱娲娴婳婴婵婶媪嫒嫔嫱嬷孙学孪宁宝实宠审宪宫宽宾寝对寻导寿将尔尘尧尴尸尽层屃屉届属屡屦屿岁岂岖岗岘岙岚岛岭岽岿峃峄峡峣峤峥峦崂崃崄崭嵘嵚嵛嵝嵴巅巩巯币帅师帏帐帘帜带帧帮帱帻帼幂幞并广庄庆庐庑库应庙庞废庼廪开异弃张弥弪弯弹强归当录彟彦彻径徕忆忏忧忾怀态怂怃怄怅怆怜总怼怿恋恳恶恸恹恺恻恼恽悦悫悬悭悯惊惧惨惩惫惬惭惮惯愍愠愤愦愿慑慭憷懑懒懔戆戋戏戗战戬户扦执扩扪扫扬扰抚抛抟抠抡抢护报担拟拢拣拥拦拧拨择挂挚挛挜挝挞挟挠挡挢挣挤挥挦捞损捡换捣据掳掴掷掸掺掼揽揿搀搁搂搅携摄摅摆摇摈摊撄撵撷撸撺擞攒敌敛数斋斓斩断无旧时旷旸昙昼昽显晋晒晓晔晕晖暂暧术机杀杂权条来杨杩极构枞枢枣枥枧枨枪枫枭柜柠柽栀栅标栈栉栊栋栌栎栏树栖样栾桊桠桡桢档桤桥桦桧桨桩梦梼梾检棂椁椟椠椤椭楼榄榇榈榉槚槛槟槠横樯樱橥橱橹橼檩欢欤欧歼殁殇残殒殓殚殡殴毁毂毕毙毡毵氇气氢氩氲汇汉汤汹沓沟没沣沤沥沦沧沨沩沪沵泞泪泶泷泸泺泻泼泽泾洁洒洼浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕涛涝涞涟涠涡涢涣涤润涧涨涩淀渊渌渍渎渐渑渔渖渗温湾湿溃溅溆溇滗滚滞滟滠满滢滤滥滦滨滩滪漤潆潇潋潍潜潴澜濑濒灏灭灯灵灾灿炀炉炖炜炝点炼炽烁烂烃烛烟烦烧烨烩烫烬热焕焖焘煅煳熘爱爷牍牦牵牺犊犟状犷犸犹狈狍狝狞独狭狮狯狰狱狲猃猎猕猡猪猫猬献獭玑玙玚玛玮环现玱玺珉珏珐珑珰珲琎琏琐琼瑶瑷璎瓒瓮瓯电画畅畲畴疖疗疟疠疡疬疮疯疴痈痉痒痖痨痪痫瘅瘆瘗瘘瘪瘫瘾瘿癞癣癫癯皑皱皲盏盐监盖盗盘眍眦眬睁睐睑瞒瞩矫矶矾矿砀码砖砗砚砜砺砻砾础硁硕硖硗硙硚确硷碍碛碜碱碹磙礼祎祢祯祷祸禀禄禅离秃秆种积称秽秾稆税稣稳穑穷窃窍窑窜窝窥窦窭竖竞笃笋笔笕笺笼笾筑筚筛筜筝筹签简箓箦箧箨箩箪箫篑篓篮篱簖籁籴类籼粜粝粤粪粮糁糇紧絷纟纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩绪绫绬续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缋缌缍缎缏缐缑缒缓缔缕编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵罂网罗罚罢罴羁羟羡翘翙翚耢耧耸耻聂聋职聍联聩聪肃肠肤肷肾肿胀胁胆胜胧胨胪胫胶脉脍脏脐脑脓脔脚脱脶脸腊腌腘腭腻腼腽腾膑臜舆舣舰舱舻艰艳艹艺节芈芗芜芦苁苇苈苋苌苍苎苏苘苹茎茏茑茔茕茧荆荐荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药莅莜莱莲莳莴莶获莸莹莺莼萚萝萤营萦萧萨葱蒇蒉蒋蒌蓝蓟蓠蓣蓥蓦蔷蔹蔺蔼蕲蕴薮藁藓虏虑虚虫虬虮虽虾虿蚀蚁蚂蚕蚝蚬蛊蛎蛏蛮蛰蛱蛲蛳蛴蜕蜗蜡蝇蝈蝉蝎蝼蝾螀螨蟏衅衔补衬衮袄袅袆袜袭袯装裆裈裢裣裤裥褛褴襁襕见观觃规觅视觇览觉觊觋觌觍觎觏觐觑觞触觯詟誉誊讠计订讣认讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诪诫诬语诮误诰诱诲诳说诵诶请诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谘谙谚谛谜谝谞谟谠谡谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶豮贝贞负贠贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赪赵赶趋趱趸跃跄跖跞践跶跷跸跹跻踊踌踪踬踯蹑蹒蹰蹿躏躜躯车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辞辩辫边辽达迁过迈运还这进远违连迟迩迳迹适选逊递逦逻遗遥邓邝邬邮邹邺邻郄郏郐郑郓郦郧郸酝酦酱酽酾酿释鉴銮錾钆钇针钉钊钋钌钍钎钏钐钑钒钓钔钕钖钗钘钙钚钛钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钶钷钸钹钺钻钼钽钾钿铀铁铂铃铄铅铆铈铉铊铋铍铎铏铐铑铒铕铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铦铧铨铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐锑锒锓锔锕锖锗错锚锜锞锟锠锡锢锣锤锥锦锨锩锫锬锭键锯锰锱锲锳锴锵锶锷锸锹锺锻锼锽锾锿镀镁镂镃镆镇镈镉镊镌镍镎镏镐镑镒镕镖镗镙镚镛镜镝镞镟镠镡镢镣镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镶长门闩闪闫闬闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阓阔阕阖阗阘阙阚阛队阳阴阵阶际陆陇陈陉陕陧陨险随隐隶隽难雏雠雳雾霁霭靓静靥鞑鞒鞯鞴韦韧韨韩韪韫韬韵页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颒颓颔颕颖颗题颙颚颛颜额颞颟颠颡颢颣颤颥颦颧风飏飐飑飒飓飔飕飖飗飘飙飚飞飨餍饤饥饦饧饨饩饪饫饬饭饮饯饰饱饲饳饴饵饶饷饸饹饺饻饼饽饾饿馀馁馂馃馄馅馆馇馈馉馊馋馌馍馎馏馐馑馒馓馔馕马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑骒骓骔骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧髅髋髌鬓魇魉鱼鱽鱾鱿鲀鲁鲂鲄鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲓鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲶鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳆鳇鳈鳉鳊鳋鳌鳍鳎鳏鳐鳑鳒鳓鳔鳕鳖鳗鳘鳙鳛鳜鳝鳞鳟鳠鳡鳢鳣鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸴鸵鸶鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹓鹔鹕鹖鹗鹘鹚鹛鹜鹝鹞鹟鹠鹡鹢鹣鹤鹥鹦鹧鹨鹩鹪鹫鹬鹭鹯鹰鹱鹲鹳鹴鹾麦麸黄黉黡黩黪黾鼋鼌鼍鼗鼹齄齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟咨尝“”"
8
+ )
9
+
10
+
11
+ def fix_chinese_text_generation_space(text):
12
+ output_text = text
13
+ output_text = re.sub(
14
+ r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.·\/\\])\s([^0-9a-zA-Z])',
15
+ r"\1\2",
16
+ output_text,
17
+ )
18
+ output_text = re.sub(
19
+ r'([^0-9a-zA-Z])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.·\/\\])',
20
+ r"\1\2",
21
+ output_text,
22
+ )
23
+ output_text = re.sub(
24
+ r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.·\/\\])\s([a-zA-Z0-9])',
25
+ r"\1\2",
26
+ output_text,
27
+ )
28
+ output_text = re.sub(
29
+ r'([a-zA-Z0-9])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.·\/\\])',
30
+ r"\1\2",
31
+ output_text,
32
+ )
33
+ output_text = re.sub(r"$\s([0-9])", r"$\1", output_text)
34
+ output_text = re.sub(",", ",", output_text)
35
+ output_text = re.sub(
36
+ r"([0-9]),([0-9])", r"\1,\2", output_text
37
+ ) # fix comma in numbers
38
+ # fix multiple commas
39
+ output_text = re.sub(r"\s?[,]+\s?", ",", output_text)
40
+ output_text = re.sub(r"\s?[、]+\s?", "、", output_text)
41
+ # fix period
42
+ output_text = re.sub(r"\s?[。]+\s?", "。", output_text)
43
+ # fix ...
44
+ output_text = re.sub(r"\s?\.{3,}\s?", "...", output_text)
45
+ # fix exclamation mark
46
+ output_text = re.sub(r"\s?[!!]+\s?", "!", output_text)
47
+ # fix question mark
48
+ output_text = re.sub(r"\s?[??]+\s?", "?", output_text)
49
+ # fix colon
50
+ output_text = re.sub(r"\s?[::]+\s?", ":", output_text)
51
+ # fix quotation mark
52
+ output_text = re.sub(r'\s?(["“”\']+)\s?', r"\1", output_text)
53
+ # fix semicolon
54
+ output_text = re.sub(r"\s?[;;]+\s?", ";", output_text)
55
+ # fix dots
56
+ output_text = re.sub(r"\s?([~●.…]+)\s?", r"\1", output_text)
57
+ output_text = re.sub(r"\s?\[…\]\s?", "", output_text)
58
+ output_text = re.sub(r"\s?\[\.\.\.\]\s?", "", output_text)
59
+ output_text = re.sub(r"\s?\.{3,}\s?", "...", output_text)
60
+ # fix slash
61
+ output_text = re.sub(r"\s?[//]+\s?", "/", output_text)
62
+ # fix dollar sign
63
+ output_text = re.sub(r"\s?[$$]+\s?", "$", output_text)
64
+ # fix @
65
+ output_text = re.sub(r"\s?([@@]+)\s?", "@", output_text)
66
+ # fix baskets
67
+ output_text = re.sub(r"\s?([\[\(<��【「『()』」】〗>\)\]]+)\s?", r"\1", output_text)
68
+
69
+ return output_text
70
+
71
+
72
+ class TranslationPipeline(HfTranslationPipeline):
73
+ def __init__(
74
+ self,
75
+ model,
76
+ tokenizer,
77
+ device=None,
78
+ max_length=512,
79
+ src_lang=None,
80
+ tgt_lang=None,
81
+ num_beams=3,
82
+ do_sample=True,
83
+ top_k=50,
84
+ top_p=0.95,
85
+ temperature=1.0,
86
+ repetition_penalty=1.0,
87
+ length_penalty=1.0,
88
+ sequence_bias=None,
89
+ bad_words_ids=None,
90
+ no_repeat_ngram_size=0,
91
+ ):
92
+ self.model = model
93
+ self.tokenizer = tokenizer
94
+
95
+ def get_tokens(word):
96
+ return tokenizer([word], add_special_tokens=False).input_ids[0]
97
+
98
+ bad_words_ids = [get_tokens(char) for char in hans_chars]
99
+
100
+ super().__init__(
101
+ self.model,
102
+ self.tokenizer,
103
+ device=device,
104
+ max_length=max_length,
105
+ src_lang=src_lang,
106
+ tgt_lang=tgt_lang,
107
+ num_beams=num_beams,
108
+ do_sample=do_sample,
109
+ top_k=top_k if do_sample == True else None,
110
+ top_p=top_p if do_sample == True else None,
111
+ temperature=temperature,
112
+ repetition_penalty=repetition_penalty,
113
+ length_penalty=length_penalty,
114
+ sequence_bias=sequence_bias,
115
+ bad_words_ids=bad_words_ids,
116
+ no_repeat_ngram_size=no_repeat_ngram_size,
117
+ )
118
+
119
+ def postprocess(
120
+ self,
121
+ model_outputs,
122
+ return_type=ReturnType.TEXT,
123
+ clean_up_tokenization_spaces=True,
124
+ ):
125
+ records = super().postprocess(
126
+ model_outputs,
127
+ return_type=return_type,
128
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
129
+ )
130
+ for rec in records:
131
+ translation_text = fix_chinese_text_generation_space(
132
+ rec["translation_text"].strip()
133
+ )
134
+
135
+ rec["translation_text"] = translation_text
136
+ return records
137
+
138
+ def __call__(self, *args, **kwargs):
139
+ records = super().__call__(*args, **kwargs)
140
+
141
+ return records
142
+
143
+
144
+ if __name__ == "__main__":
145
+ from transformers import BertTokenizerFast
146
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM
147
+
148
+ model_id = "hon9kon9ize/bart-translation-zh-yue-onnx"
149
+
150
+ tokenizer = BertTokenizerFast.from_pretrained(model_id)
151
+ model = ORTModelForSeq2SeqLM.from_pretrained(model_id, use_cache=False)
152
+ pipe = TranslationPipeline(model=model, tokenizer=tokenizer)
153
+
154
+ print(
155
+ pipe(
156
+ "近年成为许多港人热门移居地的英国中部城巿诺定咸(又译诺丁汉,Nottingham),多年来一直面对财政困境,市议会周三(11月29日)宣布破产,是继英国第二大城市伯明翰今年9月宣布破产后,近期「爆煲」的另一个英国主要城市。诺定咸除了维持法例规定必须提供的服务外,巿政府将暂停所有非必要的公共开支。",
157
+ max_length=300,
158
+ )
159
+ )
translator.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+ import traceback
3
+ from typing import List, Union
4
+ from datasets import Dataset
5
+ import re
6
+ import pickle
7
+ import os
8
+ from transformers.pipelines.pt_utils import KeyDataset
9
+ from transformers import AutoTokenizer
10
+ from tqdm.auto import tqdm
11
+
12
+ URL_REGEX = r"\b(https?://\S+)\b"
13
+ EMAIL_REGEX = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
14
+ TAG_REGEX = r"<[^>]+>"
15
+ HANDLE_REGEX = r"[^a-zA-Z](@\w+)"
16
+
17
+
18
+ class Translator:
19
+ def __init__(
20
+ self,
21
+ pipe: Callable,
22
+ max_length: int = 500,
23
+ batch_size: int = 16,
24
+ save_every_step=100,
25
+ text_key="text",
26
+ save_filename=None,
27
+ replace_chinese_puncts=False,
28
+ verbose=False,
29
+ ):
30
+ self.pipe = pipe
31
+ self.max_length = max_length
32
+ self.batch_size = batch_size
33
+ self.save_every_step = save_every_step
34
+ self.save_filename = save_filename
35
+ self.text_key = text_key
36
+ self.replace_chinese_puncts = replace_chinese_puncts
37
+ self.verbose = verbose
38
+
39
+ if max_length == None and hasattr(pipe.model.config, "max_length"):
40
+ self.max_length = pipe.model.config.max_length
41
+
42
+ def _is_chinese(self, text: str) -> bool:
43
+ return (
44
+ re.search(
45
+ r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002ebef\U00030000-\U000323af\ufa0e\ufa0f\ufa11\ufa13\ufa14\ufa1f\ufa21\ufa23\ufa24\ufa27\ufa28\ufa29\u3006\u3007][\ufe00-\ufe0f\U000e0100-\U000e01ef]?",
46
+ text,
47
+ )
48
+ is not None
49
+ )
50
+
51
+ def _split_sentences(self, text: str) -> List[str]:
52
+ tokens = self.pipe.tokenizer(text, add_special_tokens=False)
53
+ token_size = len(tokens.input_ids)
54
+
55
+ if len(text) <= self.max_length:
56
+ return [text]
57
+
58
+ delimiter = set()
59
+ delimiter.update("。!?;…!?;")
60
+ sent_list = []
61
+ sent = text
62
+
63
+ while token_size > self.max_length:
64
+ orig_sent_len = token_size
65
+
66
+ # find the index of delimiter near the max_length
67
+ for i in range(token_size - 2, 0, -1):
68
+ token = tokens.token_to_chars(0, i)
69
+ char = sent[token.start : token.end]
70
+
71
+ if char in delimiter:
72
+ split_char_index = token.end
73
+ next_sent = sent[split_char_index:]
74
+
75
+ if len(next_sent) == 1:
76
+ continue
77
+
78
+ sent_list = [next_sent] + sent_list
79
+ sent = sent[0:split_char_index]
80
+ break
81
+
82
+ tokens = self.pipe.tokenizer(sent, add_special_tokens=False)
83
+ token_size = len(tokens.input_ids)
84
+
85
+ # no delimiter found, leave the sentence as it is
86
+ if token_size == orig_sent_len:
87
+ sent_list = [sent] + sent_list
88
+ sent = ""
89
+ break
90
+
91
+ if len(sent) > 0:
92
+ sent_list = [sent] + sent_list
93
+
94
+ return sent_list
95
+
96
+ def _preprocess(self, text: str) -> (str, str):
97
+ # extract entities
98
+ tags = re.findall(TAG_REGEX, text)
99
+ handles = re.findall(HANDLE_REGEX, text)
100
+ urls = re.findall(URL_REGEX, text)
101
+ emails = re.findall(EMAIL_REGEX, text)
102
+ entities = urls + emails + tags + handles
103
+
104
+ # TODO: escape entity placeholders
105
+
106
+ for i, entity in enumerate(entities):
107
+ text = text.replace(entity, "eeee[%d]" % i, 1)
108
+
109
+ lines = text.split("\n")
110
+ sentences = []
111
+ num_tokens = []
112
+ template = text.replace("{", "{{").replace("}", "}}")
113
+ chunk_index = 0
114
+
115
+ for line in lines:
116
+ sentence = line.strip()
117
+
118
+ if len(sentence) > 0 and self._is_chinese(sentence):
119
+ chunks = self._split_sentences(sentence)
120
+
121
+ for chunk in chunks:
122
+ sentences.append(chunk)
123
+ tokens = self.pipe.tokenizer(chunk, add_special_tokens=False)
124
+ num_tokens.append(len(tokens.input_ids))
125
+ chunk = chunk.replace("{", "{{").replace("}", "}}")
126
+ template = template.replace(chunk, "{%d}" % chunk_index, 1)
127
+ chunk_index += 1
128
+
129
+ return sentences, template, num_tokens, entities
130
+
131
+ def _postprocess(
132
+ self,
133
+ template: str,
134
+ src_sentences: List[str],
135
+ translations: List[str],
136
+ entities: List[str],
137
+ ) -> str:
138
+ processed = []
139
+ alphanumeric_regex = re.compile(
140
+ "([a-zA-Za-zA-Z0-9\d+'\",,(\()\)::;;“”。·\.\??\!!‘’$\[\]<>/]+)"
141
+ )
142
+
143
+ def hash_text(text: List[str]) -> str:
144
+ text = "|".join(text)
145
+ puncts_map = str.maketrans(",;:()。?!“”‘’", ",;:().?!\"\"''")
146
+ text = text.translate(puncts_map)
147
+ return text.lower()
148
+
149
+ for i, p in enumerate(translations):
150
+ src_sentence = src_sentences[i]
151
+
152
+ if self.replace_chinese_puncts:
153
+ p = re.sub(",", ",", p) # replace all commas
154
+ p = re.sub(";", ";", p) # replace semi-colon
155
+ p = re.sub(":", ":", p) # replace colon
156
+ p = re.sub("\(", "(", p) # replace round basket
157
+ p = re.sub("\)", ")", p) # replace round basket
158
+ p = re.sub(r"([\d]),([\d])", r"\1,\2", p)
159
+
160
+ src_matches = re.findall(alphanumeric_regex, src_sentence)
161
+ tgt_matches = re.findall(alphanumeric_regex, p)
162
+
163
+ # length not match or no match
164
+ if (
165
+ len(src_matches) != len(tgt_matches)
166
+ or len(src_matches) == 0
167
+ or len(tgt_matches) == 0
168
+ ):
169
+ processed.append(p)
170
+ continue
171
+
172
+ # normalize full-width to half-width and lower case
173
+ src_hashes = hash_text(src_matches)
174
+ translated_hashes = hash_text(tgt_matches)
175
+
176
+ if src_hashes != translated_hashes:
177
+ # fix unmatched
178
+ for j in range(len(src_matches)):
179
+ if src_matches[j] != tgt_matches[j]:
180
+ p = p.replace(tgt_matches[j], src_matches[j], 1)
181
+
182
+ processed.append(p)
183
+
184
+ output = template.format(*processed)
185
+
186
+ # replace entities
187
+ for i, entity in enumerate(entities):
188
+ output = output.replace("eeee[%d]" % i, entity, 1)
189
+
190
+ # TODO: unescape entity placeholders
191
+
192
+ # fix repeated punctuations
193
+ output = re.sub(r"([「」()『』《》。,:])\1+", r"\1", output)
194
+
195
+ # fix brackets
196
+ if "“" in output:
197
+ output = re.sub("“", "「", output)
198
+ if "”" in output:
199
+ output = re.sub("”", "」", output)
200
+
201
+ return output
202
+
203
+ def _save(self, translations):
204
+ with open(self.save_filename, "wb") as f:
205
+ pickle.dump(translations, f)
206
+
207
+ def __call__(self, inputs: Union[List[str], Dataset]) -> List[str]:
208
+ templates = []
209
+ sentences = []
210
+ num_tokens = []
211
+ sentence_indices = []
212
+ outputs = []
213
+ translations = []
214
+ entities_list = []
215
+ resume_from_file = None
216
+
217
+ if isinstance(inputs, Dataset):
218
+ ds = inputs
219
+ else:
220
+ if isinstance(inputs, str):
221
+ inputs = [inputs]
222
+ ds = Dataset.from_list([{"text": text} for text in inputs])
223
+
224
+ for i, text_input in tqdm(
225
+ enumerate(ds), total=len(ds), desc="Preprocessing", disable=not self.verbose
226
+ ):
227
+ chunks, template, num_tokens, entities = self._preprocess(
228
+ text_input["text"]
229
+ )
230
+ templates.append(template)
231
+ sentence_indices.append([])
232
+ entities_list.append(entities)
233
+
234
+ for j, chunk in enumerate(chunks):
235
+ sentences.append(chunk)
236
+ sentence_indices[len(sentence_indices) - 1].append(len(sentences) - 1)
237
+ num_tokens.append(num_tokens[j])
238
+
239
+ if self.save_filename:
240
+ resume_from_file = (
241
+ self.save_filename if os.path.isfile(self.save_filename) else None
242
+ )
243
+
244
+ if resume_from_file != None:
245
+ translations = pickle.load(open(resume_from_file, "rb"))
246
+
247
+ if self.verbose:
248
+ print("translated:", len(translations))
249
+ print("to translate:", len(sentences) - len(translations))
250
+
251
+ if resume_from_file != None:
252
+ print(
253
+ "Resuming from {}({} records)".format(
254
+ resume_from_file, len(translations)
255
+ )
256
+ )
257
+
258
+ ds = Dataset.from_list(
259
+ [{"text": text} for text in sentences[len(translations) :]]
260
+ )
261
+
262
+ max_token_length = max(num_tokens)
263
+
264
+ if self.verbose:
265
+ print("Max Length:", max_token_length)
266
+
267
+ total_records = len(ds)
268
+
269
+ if total_records > 0:
270
+ step = 0
271
+
272
+ with tqdm(
273
+ disable=not self.verbose, desc="Translating", total=total_records
274
+ ) as pbar:
275
+ for out in self.pipe(
276
+ KeyDataset(ds, self.text_key),
277
+ batch_size=self.batch_size,
278
+ max_length=self.max_length,
279
+ ):
280
+ translations.append(out[0])
281
+
282
+ # export generate result every n steps
283
+ if (
284
+ step != 0
285
+ and self.save_filename != None
286
+ and step % self.save_every_step == 0
287
+ ):
288
+ self._save(translations)
289
+
290
+ step += 1
291
+
292
+ pbar.update(1)
293
+
294
+ if self.save_filename != None and total_records > 0:
295
+ self._save(translations)
296
+
297
+ for i, template in tqdm(
298
+ enumerate(templates),
299
+ total=len(templates),
300
+ desc="Postprocessing",
301
+ disable=not self.verbose,
302
+ ):
303
+ try:
304
+ src_sentences = [sentences[index] for index in sentence_indices[i]]
305
+ tgt_sentences = [
306
+ translations[index]["translation_text"]
307
+ for index in sentence_indices[i]
308
+ ]
309
+ output = self._postprocess(
310
+ template, src_sentences, tgt_sentences, entities_list[i]
311
+ )
312
+ outputs.append(output)
313
+ except Exception as error:
314
+ print(error)
315
+ print(template)
316
+ traceback.print_exc()
317
+ # print(template, sentence_indices[i], len(translations))
318
+
319
+ return outputs
320
+
321
+
322
+ class Object(object):
323
+ pass
324
+
325
+
326
+ class FakePipe(object):
327
+ def __init__(self, max_length: int = 500):
328
+ self.model = Object()
329
+ self.model.config = Object()
330
+ self.model.config.max_length = max_length
331
+ self.tokenizer = AutoTokenizer.from_pretrained(
332
+ "indiejoseph/bart-translation-zh-yue"
333
+ )
334
+
335
+ def __call__(self, text: List[str], batch_size: str, max_length: int):
336
+ for i in range(len(text)):
337
+ sentence = text[i]
338
+ # extract entities
339
+ tags = re.findall(TAG_REGEX, sentence)
340
+ handles = re.findall(HANDLE_REGEX, sentence)
341
+ urls = re.findall(URL_REGEX, sentence)
342
+ emails = re.findall(EMAIL_REGEX, sentence)
343
+ entities = urls + emails + tags + handles
344
+
345
+ for i, entity in enumerate(entities):
346
+ sentence = sentence.replace(entity, "eeee[%d]" % i, 1)
347
+
348
+ if "123" in sentence:
349
+ yield [{"translation_text": sentence.replace("123", "123")}]
350
+ continue
351
+ if "abc" in sentence:
352
+ yield [{"translation_text": sentence.replace("abc", "ABC")}]
353
+ continue
354
+ if "Acetaminophen" in sentence:
355
+ yield [
356
+ {
357
+ "translation_text": sentence.replace(
358
+ "Acetaminophen", "ACEtaminidien"
359
+ )
360
+ }
361
+ ]
362
+ continue
363
+ yield [{"translation_text": sentence}]
364
+
365
+
366
+ if __name__ == "__main__":
367
+ fake_pipe = FakePipe(60)
368
+
369
+ translator = Translator(fake_pipe, max_length=60, batch_size=2, verbose=True)
370
+
371
+ text1 = "对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:"
372
+ text2 = """对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:
373
+
374
+ ```
375
+ # 设置用于匹配输入的关键字,并定义相应的回答数据字典。
376
+ keywords = {'你好': '你好!很高兴见到你。',
377
+ '再见': '再见!有机会再聊。',
378
+ '你叫什么': '我是一个聊天机器人。',
379
+ '你是谁': '我是一个基于人工智能技术制作的聊天机器人。'}
380
+
381
+ # 定义用于处理用户输入的函数。
382
+ def chatbot(input_text):
383
+ # 遍历关键字数据字典,匹配用户的输入。
384
+ for key in keywords:
385
+ if key in input_text:
386
+ # 如果匹配到了关键字,返回相应的回答。
387
+ return keywords[key]
388
+ # 如果没有找到匹配的关键字,返回默认回答。
389
+ return "对不起,我不知道你在说什么。"
390
+
391
+ # 运行聊天机器人。
392
+ while True:
393
+ # 获取用户输入。
394
+ user_input = input('用户: ')
395
+ # 如果用户输入“再见”,退出程序。
396
+ if user_input == '再见':
397
+ break
398
+ # 处理用户输入,并打印回答。
399
+ print('机器人: ' + chatbot(user_input))
400
+ ```
401
+
402
+ 这是一个非常简单的例子。对于实用的聊天机器人,可能需要使用更复杂的 NLP 技术和机器学习模型,以更好地理解和回答用户的问题。"""
403
+ text3 = "布洛芬(Ibuprofen)同撲熱息痛(Acetaminophen)係兩種常見嘅非處方藥,用於緩解疼痛、發燒同關節痛。"
404
+ text4 = "123 “abc” def's http://www.google.com [email protected] @abc 網址:http://localhost/abc下載"
405
+ text5 = "新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」"
406
+ outputs = translator([text1, text2, text3, text4, text5])
407
+
408
+ # for i, line in enumerate(outputs[1].split("\n")):
409
+ # input_text = text2.split("\n")[i]
410
+
411
+ # if line != input_text:
412
+ # print(line, text2.split("\n")[i])
413
+
414
+ assert outputs[0] == text1
415
+ assert outputs[1] == text2.replace("“", "「").replace("”", "」")
416
+ assert outputs[2] == text3
417
+ assert outputs[3] == text4.replace("“", "「").replace("”", "」")
418
+ assert outputs[4] == text5
419
+
420
+ # exception
421
+ assert (
422
+ len(
423
+ translator._split_sentences(
424
+ "新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」"
425
+ )
426
+ )
427
+ == 1
428
+ )
429
+
430
+ translator = Translator(fake_pipe, max_length=5, batch_size=2, verbose=True)
431
+
432
+ assert (
433
+ len(
434
+ translator._split_sentences("====。====。====。====。====。====。====。====。====。")
435
+ )
436
+ == 9
437
+ )