
rag_tokenizer.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
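"""RagTokenizer: a bilingual (Chinese/English) tokenizer backed by a datrie
frequency dictionary. Chinese text is segmented with forward/backward maximum
matching plus a DFS over the trie to reconcile disagreements; non-Chinese text
falls back to NLTK word tokenization with lemmatization and stemming."""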
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import logger


class RagTokenizer:
    def key_(self, line):
        # Forward trie key: the lowercased UTF-8 bytes rendered as their
        # printable repr (the trie alphabet is string.printable).
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        # Backward trie key: the reversed, lowercased term prefixed with "DD"
        # so reverse entries never collide with forward ones.
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
        # Dictionary format: one entry per line, "<term> <frequency> <POS tag>",
        # separated by spaces or tabs. Frequencies are stored as rounded
        # log(freq / DENOMINATOR) values.
        logger.info(f"[HUQIE]:Build trie {fnm}")
        try:
            of = open(fnm, "r", encoding='utf-8')
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1
            self.trie_.save(fnm + ".trie")
            of.close()
        except Exception:
            logger.exception(f"[HUQIE]:Build trie {fnm} failed")

    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.trie_ = datrie.Trie(string.printable)
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"

        try:
            # Prefer the pre-built trie cache; rebuild it from the plain-text
            # dictionary only if loading fails.
            self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
            return
        except Exception:
            logger.exception("[HUQIE]:Build default trie")
            self.trie_ = datrie.Trie(string.printable)

        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters to half-width."""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # keep the original char if it does not map to a half-width char
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        """Convert traditional Chinese to simplified Chinese."""
        return HanziConv.toSimplified(line)
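
    # dfs_ exhaustively enumerates segmentations of `chars` starting at
    # position `s` by walking the frequency trie: every dictionary word that
    # starts at `s` (with some pruning) spawns a recursive branch, and every
    # complete segmentation is appended to `tkslist` as a list of
    # (token, (log-frequency, POS tag)) pairs. Characters not found in the
    # dictionary are emitted as single tokens with the sentinel (-12, '').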
    def dfs_(self, chars, s, preTks, tkslist):
        MAX_L = 10
        res = s
        # if s > MAX_L or s >= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res

        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
                    self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(
                preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        ################
        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)
            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break
            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                if k in self.trie_:
                    pretks.append((t, self.trie_[k]))
                else:
                    pretks.append((t, (-12, '')))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))

        return self.dfs_(chars, s + 1, preTks, tkslist)

    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]
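
    # score_ ranks one candidate segmentation: F is the mean stored
    # log-frequency, L is the share of multi-character tokens, and B / len(tks)
    # rewards segmentations that use fewer (hence longer) tokens.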
    def score_(self, tfts):
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        F /= len(tks)
        L /= len(tks)
        logger.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
        patts = [
            (r"[ ]+", " "),
            (r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
        ]
        # for p, s in patts: tks = re.sub(p, s, tks)
        # if split chars is part of token
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split(" ")
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E
        return " ".join(res)
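
    # Forward and backward maximum matching: starting from one end of the
    # line, extend a candidate while the trie still has keys with that prefix,
    # then back off to the longest entry actually present. Both return
    # (tokens, score) via score_ so tokenize() can compare the two directions.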
    def maxForward_(self, line):
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(
                    self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

    def english_normalize_(self, tks):
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
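
    # tokenize() normalizes the input (full-width -> half-width, traditional
    # -> simplified, lowercase). Text with no Chinese characters goes straight
    # to NLTK tokenization with lemmatization and stemming; otherwise each
    # chunk is segmented by both maxForward_ and maxBackward_, spans where the
    # two directions agree are kept, and disputed spans are re-segmented with
    # dfs_ / sortTks_, taking the best-scoring split.
    # Typical call pattern (via the module-level aliases defined below):
    #     tokenize("some mixed 中文 and English text")
    #     fine_grained_tokenize(tokenize("some mixed 中文 and English text"))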
    def tokenize(self, line):
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        zh_num = len([1 for c in line if is_chinese(c)])
        if zh_num == 0:
            return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])

        arr = re.split(self.SPLIT_CHAR, line)
        res = []
        for L in arr:
            if len(L) < 2 or re.match(
                    r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue

            # use max-forward matching for the first pass
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                logger.debug("[FW] {} {}".format(tks, s))
                logger.debug("[BW] {} {}".format(tks1, s1))

            i, j, _i, _j = 0, 0, 0, 0
            same = 0
            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                same += 1
            if same > 0:
                res.append(" ".join(tks[j: j + same]))
            _i = i + same
            _j = j + same
            j = _j + 1
            i = _i + 1

            while i < len(tks1) and j < len(tks):
                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
                if tk1 != tk:
                    if len(tk1) > len(tk):
                        j += 1
                    else:
                        i += 1
                    continue

                if tks1[i] != tks[j]:
                    i += 1
                    j += 1
                    continue
                # backward tokens from _i to i differ from forward tokens from _j to j
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                    same += 1
                res.append(" ".join(tks[j: j + same]))
                _i = i + same
                _j = j + same
                j = _j + 1
                i = _i + 1

            if _i < len(tks1):
                assert _j < len(tks)
                assert "".join(tks1[_i:]) == "".join(tks[_j:])
                tkslist = []
                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

        res = " ".join(self.english_normalize_(res))
        logger.debug("[TKS] {}".format(self.merge_(res)))
        return self.merge_(res)
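
    # fine_grained_tokenize() re-splits coarse tokens: mostly non-Chinese
    # input is only broken on '/', while tokens of 3..10 characters are
    # re-segmented with dfs_ and the second-best split from sortTks_ is used,
    # falling back to the whole token when the split degenerates into single
    # characters (or, for alphabetic tokens, into pieces shorter than 3 letters).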
    def fine_grained_tokenize(self, tks):
        tks = tks.split(" ")
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)
            res.append(stk)

        return " ".join(self.english_normalize_(res))


def is_chinese(s):
    if s >= u'\u4e00' and s <= u'\u9fa5':
        return True
    else:
        return False


def is_number(s):
    if s >= u'\u0030' and s <= u'\u0039':
        return True
    else:
        return False


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (
            s >= u'\u0061' and s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
    tks = []
    for t in txt.split(" "):
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]
                            ) and re.match(r".*[a-zA-Z]$", t):
            tks.append(" ")
        tks.append(t)
    return tks


# Module-level singleton and convenience aliases.
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
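

# CLI usage: python rag_tokenizer.py [<user_dict> <text_file>]
# With no arguments the block below only runs the demo sentences; with two
# arguments it additionally loads <user_dict> as a user dictionary and
# tokenizes <text_file> line by line.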
if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    logger.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    logger.info(tknzr.fine_grained_tokenize(tks))

    if len(sys.argv) < 3:
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        logger.info(tknzr.tokenize(line))
    of.close()