
rag_tokenizer.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory

class RagTokenizer:
    def key_(self, line):
        # Trie key: the lower-cased UTF-8 byte representation of the term.
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        # Reversed key (prefixed with "DD"), used for backward maximum matching.
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
        # Build the trie from a dictionary file of "term frequency tag" lines
        # and cache it next to the file as "<fnm>.trie".
        logging.info(f"[HUQIE]:Build trie from {fnm}")
        try:
            of = open(fnm, "r", encoding='utf-8')
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1

            dict_file_cache = fnm + ".trie"
            logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
            self.trie_.save(dict_file_cache)
            of.close()
        except Exception:
            logging.exception(f"[HUQIE]:Build trie {fnm} failed")
    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"

        trie_file_name = self.DIR_ + ".txt.trie"
        # check whether the trie cache file exists
        if os.path.exists(trie_file_name):
            try:
                # load the trie from the cache file
                self.trie_ = datrie.Trie.load(trie_file_name)
                return
            except Exception:
                # failed to load the cached trie; fall back to building the default trie
                logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
                self.trie_ = datrie.Trie(string.printable)
        else:
            # cache file does not exist; build the default trie
            logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
            self.trie_ = datrie.Trie(string.printable)

        # load data from the dictionary file and save it to the trie cache file
        self.loadDict_(self.DIR_ + ".txt")
    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)
    def _strQ2B(self, ustring):
        """Convert full-width characters to half-width characters."""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # if the result is not a half-width character, keep the original character
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        """Convert traditional Chinese to simplified Chinese."""
        return HanziConv.toSimplified(line)
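
    # Illustration of the two normalizers above (assumed sample inputs and outputs;
    # _strQ2B maps full-width ASCII to half-width, _tradi2simp relies on HanziConv):
    #   self._strQ2B("Ｈｅｌｌｏ，Ｗｏｒｌｄ")  ->  "Hello,World"
    #   self._tradi2simp("中華人民共和國")      ->  "中华人民共和国"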
    def dfs_(self, chars, s, preTks, tkslist):
        # Depth-first enumeration of candidate segmentations of `chars` starting
        # at position `s`; every completed candidate (a list of (token, (freq, tag))
        # pairs) is appended to `tkslist`, capped at 2048 candidates.
        res = s
        if len(tkslist) >= 2048:
            return res
        # if s > MAX_L or s >= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res
        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
                    self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(
                preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        ################
        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)
            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break
            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                if k in self.trie_:
                    pretks.append((t, self.trie_[k]))
                else:
                    pretks.append((t, (-12, '')))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))

        return self.dfs_(chars, s + 1, preTks, tkslist)
    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]

    def score_(self, tfts):
        # Score a candidate segmentation: B / len(tks) rewards fewer tokens,
        # L rewards the share of multi-character tokens, and F is the summed
        # log-frequency value stored in the trie.
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        #F /= len(tks)
        L /= len(tks)
        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        # Rank candidate segmentations by score_, best first.
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)
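
    # Minimal sketch of how dfs_ and sortTks_ combine (tknzr is a RagTokenizer
    # instance; the sample text comes from the __main__ demo below, and the shown
    # result is illustrative, since the actual split depends on the huqie dictionary):
    #   tkslist = []
    #   tknzr.dfs_("南京市长江大桥", 0, [], tkslist)     # enumerate candidate segmentations
    #   best_tokens = tknzr.sortTks_(tkslist)[0][0]      # e.g. ["南京市", "长江大桥"]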
    def merge_(self, tks):
        # Re-join adjacent tokens when their concatenation contains split
        # characters yet is itself a dictionary term.
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split()
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)
    def maxForward_(self, line):
        # Forward maximum matching against the trie.
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(
                    self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        # Backward maximum matching against the trie (uses the reversed rkey_ keys).
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])
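
    # Note: tokenize() below runs both matchers, keeps the spans where the
    # forward and backward results agree, and re-segments the disagreeing spans
    # with dfs_ plus sortTks_.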
    def english_normalize_(self, tks):
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]

    def _split_by_lang(self, line):
        txt_lang_pairs = []
        arr = re.split(self.SPLIT_CHAR, line)
        for a in arr:
            if not a:
                continue
            s = 0
            e = s + 1
            zh = is_chinese(a[s])
            while e < len(a):
                _zh = is_chinese(a[e])
                if _zh == zh:
                    e += 1
                    continue
                txt_lang_pairs.append((a[s: e], zh))
                s = e
                e = s + 1
                zh = _zh
            if s >= len(a):
                continue
            txt_lang_pairs.append((a[s: e], zh))
        return txt_lang_pairs
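
    # Illustration (assumed input): after splitting on SPLIT_CHAR, each fragment
    # is cut into runs of the same script, paired with an is-Chinese flag, e.g.
    #   self._split_by_lang("hello世界") -> [("hello", False), ("世界", True)]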
    def tokenize(self, line):
        line = re.sub(r"\W+", " ", line)
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        arr = self._split_by_lang(line)
        res = []
        for L, lang in arr:
            if not lang:
                # non-Chinese fragment: NLTK tokenization plus lemmatization and stemming
                res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
                continue
            if len(L) < 2 or re.match(
                    r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue

            # use forward maximum matching first, then backward
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                logging.debug("[FW] {} {}".format(tks, s))
                logging.debug("[BW] {} {}".format(tks1, s1))

            i, j, _i, _j = 0, 0, 0, 0
            same = 0
            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                same += 1
            if same > 0:
                res.append(" ".join(tks[j: j + same]))
            _i = i + same
            _j = j + same
            j = _j + 1
            i = _i + 1

            while i < len(tks1) and j < len(tks):
                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
                if tk1 != tk:
                    if len(tk1) > len(tk):
                        j += 1
                    else:
                        i += 1
                    continue

                if tks1[i] != tks[j]:
                    i += 1
                    j += 1
                    continue
                # backward tokens from _i to i differ from forward tokens from _j to j;
                # re-segment that span with dfs_ and keep the best-scoring result
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                    same += 1
                res.append(" ".join(tks[j: j + same]))

                _i = i + same
                _j = j + same
                j = _j + 1
                i = _i + 1

            if _i < len(tks1):
                assert _j < len(tks)
                assert "".join(tks1[_i:]) == "".join(tks[_j:])
                tkslist = []
                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

        res = " ".join(res)
        logging.debug("[TKS] {}".format(self.merge_(res)))
        return self.merge_(res)
    def fine_grained_tokenize(self, tks):
        # Re-split long coarse tokens into finer pieces; for text that is
        # predominantly non-Chinese, only split on "/".
        tks = tks.split()
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)
            res.append(stk)
        return " ".join(self.english_normalize_(res))
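
    # Illustration (assumed; actual splits depend on the dictionary): a coarse
    # token such as "数据分析" from the __main__ demo may be re-split here into
    # "数据 分析", while short tokens and pure numbers pass through unchanged.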

def is_chinese(s):
    if s >= u'\u4e00' and s <= u'\u9fa5':
        return True
    else:
        return False


def is_number(s):
    if s >= u'\u0030' and s <= u'\u0039':
        return True
    else:
        return False


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (
            s >= u'\u0061' and s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
    tks = []
    for t in txt.split():
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]
                            ) and re.match(r".*[a-zA-Z]$", t):
            tks.append(" ")
        tks.append(t)
    return tks

tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
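
# Minimal usage sketch for downstream callers via the module-level aliases above
# (the sample text is taken from the __main__ demo below):
#   coarse = tokenize("多校划片就是一个小区对应多个小学初中")
#   fine = fine_grained_tokenize(coarse)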

if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    logging.info(tknzr.fine_grained_tokenize(tks))

    # optional CLI mode: argv[1] = user dictionary file, argv[2] = text file to tokenize line by line
    if len(sys.argv) < 3:
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        logging.info(tknzr.tokenize(line))
    of.close()