devve1 committed
Commit
1ca56eb
1 Parent(s): 56374e1

Create rag_tokenizer.py

Files changed (1): rag_tokenizer.py (+440, -0)
rag_tokenizer.py ADDED
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from huggingface_hub import snapshot_download
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory

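# How this tokenizer works: dictionary terms live in a datrie.Trie. The forward
# key (key_) maps each term to a (log-frequency, POS-tag) pair, and a reversed,
# "DD"-prefixed key (rkey_) is also stored so that backward matching can use
# prefix lookups. Chinese spans are segmented with forward and backward maximum
# matching (maxForward_ / maxBackward_); wherever the two disagree, dfs_
# enumerates alternative segmentations and score_/sortTks_ pick the
# highest-scoring one. Purely non-Chinese input falls back to NLTK's
# word_tokenize plus stemming and lemmatization.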
class RagTokenizer:
    def key_(self, line):
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
        # Dictionary file format, as read below: one entry per line,
        # whitespace-separated "<term> <frequency> <POS tag>".
        print("[HUQIE]:Build trie", fnm, file=sys.stderr)
        try:
            of = open(fnm, "r", encoding='utf-8')
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1
            self.trie_.save(fnm + ".trie")
            of.close()
        except Exception as e:
            print("[HUQIE]:Failed to build trie, ", fnm, e, file=sys.stderr)

    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.trie_ = datrie.Trie(string.printable)
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
        try:
            # Prefer a pre-built trie on disk; fall back to building it from
            # the plain-text dictionary.
            self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
            return
        except Exception:
            print("[HUQIE]:Build default trie", file=sys.stderr)
            self.trie_ = datrie.Trie(string.printable)

        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters to their half-width equivalents."""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:
                # Not a half-width character after conversion; keep the original.
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        return HanziConv.toSimplified(line)

    def dfs_(self, chars, s, preTks, tkslist):
        # Depth-first enumeration of candidate segmentations of chars[s:];
        # every complete segmentation is appended to tkslist.
        MAX_L = 10
        res = s
        # if s > MAX_L or s >= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res

        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        # try every dictionary term that starts at position s
        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)

            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break

            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                pretks.append((t, self.trie_[k]))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        # fall back to a single character with a penalty score
        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))

        return self.dfs_(chars, s + 1, preTks, tkslist)

    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]

    def score_(self, tfts):
        # Score a candidate segmentation: fewer tokens (B / len(tks)), a larger
        # share of multi-character tokens (L), and a higher average
        # log-frequency (F) all raise the score.
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        F /= len(tks)
        L /= len(tks)
        if self.DEBUG:
            print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
        patts = [
            (r"[ ]+", " "),
            (r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
        ]
        # for p, s in patts: tks = re.sub(p, s, tks)

        # Re-join adjacent tokens when a split character is part of a
        # dictionary token.
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split(" ")
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)

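    # Maximum-matching segmentation: starting from one end of the span, extend
    # the candidate while the trie still contains keys with that prefix, then
    # back off to the longest exact dictionary entry. maxForward_ scans left to
    # right with the normal keys; maxBackward_ scans right to left using the
    # reversed "DD"-prefixed keys. The two can disagree on ambiguous spans, and
    # tokenize() compares their scored results.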
    def maxForward_(self, line):
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

    def english_normalize_(self, tks):
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]

    def tokenize(self, line):
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        zh_num = len([1 for c in line if is_chinese(c)])
        if zh_num == 0:
            return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])

        arr = re.split(self.SPLIT_CHAR, line)
        res = []
        for L in arr:
            if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue

            # Run forward maximum matching first, then backward; where the two
            # segmentations disagree, re-segment that span with dfs_.
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                print("[FW]", tks, s)
                print("[BW]", tks1, s1)

            diff = [0 for _ in range(max(len(tks1), len(tks)))]
            for i in range(min(len(tks1), len(tks))):
                if tks[i] != tks1[i]:
                    diff[i] = 1

            if s1 > s:
                tks = tks1

            i = 0
            while i < len(tks):
                s = i
                while s < len(tks) and diff[s] == 0:
                    s += 1
                if s == len(tks):
                    res.append(" ".join(tks[i:]))
                    break
                if s > i:
                    res.append(" ".join(tks[i:s]))

                e = s
                while e < len(tks) and e - s < 5 and diff[e] == 1:
                    e += 1

                tkslist = []
                self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                i = e + 1

        res = " ".join(self.english_normalize_(res))
        if self.DEBUG:
            print("[TKS]", self.merge_(res))
        return self.merge_(res)

    def fine_grained_tokenize(self, tks):
        tks = tks.split(" ")
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            # Mostly non-Chinese input: only split on "/" and return.
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            # Use the runner-up segmentation; the top-ranked one is usually the
            # un-split token itself.
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)

            res.append(stk)

        return " ".join(self.english_normalize_(res))

def is_chinese(s):
    return u'\u4e00' <= s <= u'\u9fa5'


def is_number(s):
    return u'\u0030' <= s <= u'\u0039'


def is_alphabet(s):
    return (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a')


def naiveQie(txt):
    tks = []
    for t in txt.split(" "):
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
            # keep a separating space between consecutive alphabetic tokens
            tks.append(" ")
        tks.append(t)
    return tks

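# Module-level convenience API: a shared RagTokenizer is built at import time
# (loading or building the default trie), and the names below are aliases to
# its bound methods so callers can simply use rag_tokenizer.tokenize(text).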
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B

if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    print(tknzr.fine_grained_tokenize(tks))
    if len(sys.argv) < 3:
        # both a user dictionary (argv[1]) and an input file (argv[2]) are required
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        print(tknzr.tokenize(line))
    of.close()
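A minimal usage sketch, assuming the module is importable as rag_tokenizer, that the bundled rag/res/huqie.txt dictionary and the NLTK data (punkt, wordnet) are available, and with a hypothetical user-dictionary path:

import rag_tokenizer

# Coarse segmentation, then the finer-grained pass typically used for indexing.
text = "多校划片就是一个小区对应多个小学初中 multi-school zoning policy"
coarse = rag_tokenizer.tokenize(text)               # space-separated tokens
fine = rag_tokenizer.fine_grained_tokenize(coarse)  # further splits long tokens
print(coarse)
print(fine)

# Optionally extend the vocabulary; each line of the file is expected to be
# "<term> <frequency> <POS-tag>" (the path below is hypothetical).
rag_tokenizer.addUserDict("/path/to/user.dict.txt")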