# rag_tokenizer.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory
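

# RagTokenizer segments mixed Chinese/English text using a frequency/POS
# dictionary ("huqie") stored in a datrie character trie, combining forward and
# backward maximum matching with a DFS-based re-segmentation of disputed spans.
# key_() turns a lower-cased token into a printable representation of its UTF-8
# bytes (the trie alphabet is string.printable), and rkey_() stores the reversed
# token under a "DD" prefix so backward matching can also use prefix lookups.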
class RagTokenizer:
    def key_(self, line):
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
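
    # loadDict_ expects one entry per line in the form "<token> <frequency> <POS tag>",
    # separated by spaces or tabs. Frequencies are stored log-scaled relative to
    # DENOMINATOR, the reverse key is registered for backward matching, and the
    # finished trie is cached next to the dictionary as "<fnm>.trie".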
    def loadDict_(self, fnm):
        logging.info(f"[HUQIE]:Build trie from {fnm}")
        try:
            of = open(fnm, "r", encoding='utf-8')
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1

            dict_file_cache = fnm + ".trie"
            logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
            self.trie_.save(dict_file_cache)
            of.close()
        except Exception:
            logging.exception(f"[HUQIE]:Build trie {fnm} failed")
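
    # __init__ first tries to load a pre-built trie cache ("huqie.txt.trie")
    # under rag/res; if loading fails or the cache is missing, it creates an
    # empty trie over string.printable and rebuilds it from "huqie.txt".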
    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"

        trie_file_name = self.DIR_ + ".txt.trie"
        # check whether the trie cache file exists
        if os.path.exists(trie_file_name):
            try:
                # load the trie from the cache file
                self.trie_ = datrie.Trie.load(trie_file_name)
                return
            except Exception:
                # failed to load the trie cache, build the default trie instead
                logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
                self.trie_ = datrie.Trie(string.printable)
        else:
            # cache file does not exist, build the default trie
            logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
            self.trie_ = datrie.Trie(string.printable)

        # load data from the dictionary file and save it to the trie cache file
        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters to half-width characters"""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            # after the conversion, if it is not a half-width character, keep the original character
            if inside_code < 0x0020 or inside_code > 0x7e:
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        return HanziConv.toSimplified(line)
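
    # dfs_ recursively enumerates candidate segmentations of `chars` starting at
    # position s, extending a candidate token only while the trie still contains
    # keys with that prefix; the "pruning" block heuristically raises the minimum
    # candidate length in a couple of cases. Each complete segmentation (a list of
    # (token, (freq, tag)) pairs) is appended to tkslist, and unknown single
    # characters fall back to a (-12, '') entry.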
    def dfs_(self, chars, s, preTks, tkslist):
        res = s
        # if s > MAX_L or s>= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res

        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
                    self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(
                preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        ################
        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)

            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break

            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                if k in self.trie_:
                    pretks.append((t, self.trie_[k]))
                else:
                    pretks.append((t, (-12, '')))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))

        return self.dfs_(chars, s + 1, preTks, tkslist)
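
    # freq() and tag() read a token's stored frequency and POS tag back out of the
    # trie (freq() undoes the log scaling applied in loadDict_). score_() rates one
    # candidate segmentation: a higher summed frequency, a larger share of
    # multi-character tokens and fewer tokens overall all raise the score, and
    # sortTks_() ranks candidate segmentations by that score in descending order.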
    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]

    def score_(self, tfts):
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        # F /= len(tks)
        L /= len(tks)
        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)
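
    # merge_ re-joins adjacent pieces of an already tokenized string when the
    # joined span contains split characters (".", "-", etc.) yet is itself a
    # known dictionary token with a non-zero frequency.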
    def merge_(self, tks):
        # handle the case where split characters are actually part of a token
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split()
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)
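
    # maxForward_ and maxBackward_ perform greedy maximum matching over the line,
    # from the left and from the right respectively, and return the (tokens, score)
    # pair produced by score_(). The backward pass matches against the reversed,
    # "DD"-prefixed keys that rkey_() stores in the trie.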
    def maxForward_(self, line):
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(
                    self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

    def english_normalize_(self, tks):
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
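
    # tokenize() is the main entry point. It normalizes the input (full-width to
    # half-width, traditional to simplified Chinese, lower-casing), hands text
    # without Chinese characters to NLTK's word_tokenize plus stemming and
    # lemmatization, and splits the rest on SPLIT_CHAR. Each chunk is segmented
    # with both forward and backward maximum matching; spans where the two passes
    # agree are kept as-is, and disputed spans are re-segmented with dfs_() and
    # scored with sortTks_(). The result is finally passed through merge_().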
    def tokenize(self, line):
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        zh_num = len([1 for c in line if is_chinese(c)])
        if zh_num == 0:
            return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])

        arr = re.split(self.SPLIT_CHAR, line)
        res = []
        for L in arr:
            if len(L) < 2 or re.match(
                    r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue
            # print(L)

            # run forward maximum matching first, then backward maximum matching
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                logging.debug("[FW] {} {}".format(tks, s))
                logging.debug("[BW] {} {}".format(tks1, s1))

            i, j, _i, _j = 0, 0, 0, 0
            same = 0
            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                same += 1
            if same > 0:
                res.append(" ".join(tks[j: j + same]))
            _i = i + same
            _j = j + same
            j = _j + 1
            i = _i + 1

            while i < len(tks1) and j < len(tks):
                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
                if tk1 != tk:
                    if len(tk1) > len(tk):
                        j += 1
                    else:
                        i += 1
                    continue

                if tks1[i] != tks[j]:
                    i += 1
                    j += 1
                    continue
                # the backward tokens from _i to i differ from the forward tokens from _j to j
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                    same += 1
                res.append(" ".join(tks[j: j + same]))
                _i = i + same
                _j = j + same
                j = _j + 1
                i = _i + 1

            if _i < len(tks1):
                assert _j < len(tks)
                assert "".join(tks1[_i:]) == "".join(tks[_j:])
                tkslist = []
                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

        res = " ".join(self.english_normalize_(res))
        logging.debug("[TKS] {}".format(self.merge_(res)))
        return self.merge_(res)
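
    # fine_grained_tokenize() splits coarse tokens further. Mostly non-Chinese
    # input is only split on "/"; short or numeric tokens are kept unchanged;
    # Chinese tokens of up to 10 characters are re-segmented with dfs_(), taking
    # the second-best candidate from sortTks_() (index [1]), which is typically
    # a finer split than the token itself.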
    def fine_grained_tokenize(self, tks):
        tks = tks.split()
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)

            res.append(stk)

        return " ".join(self.english_normalize_(res))


def is_chinese(s):
    if s >= u'\u4e00' and s <= u'\u9fa5':
        return True
    else:
        return False


def is_number(s):
    if s >= u'\u0030' and s <= u'\u0039':
        return True
    else:
        return False


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (
            s >= u'\u0061' and s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
    tks = []
    for t in txt.split():
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]
                            ) and re.match(r".*[a-zA-Z]$", t):
            tks.append(" ")
        tks.append(t)
    return tks


tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
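

# Minimal usage sketch (assumes the huqie dictionary files under rag/res and the
# required NLTK data are available; the import path below is illustrative only):
#
#     from rag.nlp import rag_tokenizer
#     coarse = rag_tokenizer.tokenize("数据分析项目经理 sql python hive")
#     fine = rag_tokenizer.fine_grained_tokenize(coarse)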
if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    logging.info(tknzr.fine_grained_tokenize(tks))

    # both a user dictionary path and an input file path are required below
    if len(sys.argv) < 3:
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        logging.info(tknzr.tokenize(line))
    of.close()