
rag_tokenizer.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory
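

# RagTokenizer builds a character trie (datrie) over the bundled "huqie"
# dictionary and combines forward/backward maximum matching with a DFS
# re-segmentation of the spans where the two disagree. English tokens are
# lemmatized and stemmed via NLTK.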
class RagTokenizer:
    def key_(self, line):
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
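
    # Each dictionary line holds "term frequency tag" separated by spaces or
    # tabs. The frequency is log-compressed against DENOMINATOR before being
    # stored, and a reversed "DD"-prefixed key is also inserted so backward
    # matching can reuse the same trie.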
    def loadDict_(self, fnm):
        logging.info(f"[HUQIE]:Build trie {fnm}")
        try:
            of = open(fnm, "r", encoding='utf-8')
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1
            self.trie_.save(fnm + ".trie")
            of.close()
        except Exception:
            logging.exception(f"[HUQIE]:Build trie {fnm} failed")

    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.trie_ = datrie.Trie(string.printable)
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"

        try:
            self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
            return
        except Exception:
            logging.exception("[HUQIE]:Build default trie")
            self.trie_ = datrie.Trie(string.printable)

        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters in a string to half-width."""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # not a half-width char after conversion, keep the original
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        return HanziConv.toSimplified(line)
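
    # dfs_ enumerates candidate segmentations of `chars` starting at position `s`:
    # it extends every dictionary-prefix match found in the trie, recurses on the
    # remainder, and falls back to a single character (with a -12 penalty score)
    # when nothing longer matches. Complete segmentations are collected in tkslist.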
    def dfs_(self, chars, s, preTks, tkslist):
        res = s
        # if s > MAX_L or s >= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res

        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
                    self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(
                preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        ################
        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)

            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break

            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                pretks.append((t, self.trie_[k]))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))

        return self.dfs_(chars, s + 1, preTks, tkslist)

    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]
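
    # score_ ranks a candidate segmentation: higher total term frequency (F), a
    # larger share of multi-character tokens (L) and fewer tokens overall
    # (B / len(tks)) all raise the score.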
    def score_(self, tfts):
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        # F /= len(tks)
        L /= len(tks)
        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
        # re-merge adjacent tokens when split characters turn out to be part of a dictionary term
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split()
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)
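
    # Greedy maximum matching against the trie: maxForward_ scans left to right,
    # maxBackward_ scans right to left using the reversed ("DD"-prefixed) keys.
    # Both return (tokens, score) via score_.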
    def maxForward_(self, line):
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(
                    self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

    def english_normalize_(self, tks):
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
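
    # tokenize pipeline: full-width -> half-width, traditional -> simplified,
    # then a pure-English shortcut via NLTK. For text containing Chinese, the
    # line is split on SPLIT_CHAR, each piece is segmented with forward and
    # backward maximum matching, and any span where the two disagree is
    # re-segmented with dfs_/sortTks_ before the result is merged.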
    def tokenize(self, line):
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        zh_num = len([1 for c in line if is_chinese(c)])
        if zh_num == 0:
            return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])

        arr = re.split(self.SPLIT_CHAR, line)
        res = []
        for L in arr:
            if len(L) < 2 or re.match(
                    r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue
            # print(L)

            # run forward maximum matching first, then backward
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                logging.debug("[FW] {} {}".format(tks, s))
                logging.debug("[BW] {} {}".format(tks1, s1))

            i, j, _i, _j = 0, 0, 0, 0
            same = 0
            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                same += 1
            if same > 0:
                res.append(" ".join(tks[j: j + same]))
            _i = i + same
            _j = j + same
            j = _j + 1
            i = _i + 1

            while i < len(tks1) and j < len(tks):
                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
                if tk1 != tk:
                    if len(tk1) > len(tk):
                        j += 1
                    else:
                        i += 1
                    continue

                if tks1[i] != tks[j]:
                    i += 1
                    j += 1
                    continue
                # backward tokens from _i to i differ from forward tokens from _j to j
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                    same += 1
                res.append(" ".join(tks[j: j + same]))

                _i = i + same
                _j = j + same
                j = _j + 1
                i = _i + 1

            if _i < len(tks1):
                assert _j < len(tks)
                assert "".join(tks1[_i:]) == "".join(tks[_j:])
                tkslist = []
                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

        res = " ".join(self.english_normalize_(res))
        logging.debug("[TKS] {}".format(self.merge_(res)))
        return self.merge_(res)
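
    # fine_grained_tokenize re-splits coarse tokens (mainly Chinese ones) into
    # smaller dictionary terms so both granularities can be indexed; input that
    # is mostly non-Chinese is only split on '/'.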
    def fine_grained_tokenize(self, tks):
        tks = tks.split()
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)
            res.append(stk)

        return " ".join(self.english_normalize_(res))


def is_chinese(s):
    if s >= u'\u4e00' and s <= u'\u9fa5':
        return True
    else:
        return False


def is_number(s):
    if s >= u'\u0030' and s <= u'\u0039':
        return True
    else:
        return False


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (
            s >= u'\u0061' and s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
    tks = []
    for t in txt.split():
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]
                            ) and re.match(r".*[a-zA-Z]$", t):
            tks.append(" ")
        tks.append(t)
    return tks
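

# A module-level tokenizer is built at import time so the aliases below can be
# called directly, e.g. rag_tokenizer.tokenize(...).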
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B


if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    logging.info(tknzr.fine_grained_tokenize(tks))

    # A user dictionary and an input file are both required beyond this point.
    if len(sys.argv) < 3:
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        logging.info(tknzr.tokenize(line))
    of.close()