#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import datrie
import math
import os
import re
import string
import sys

from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer

from api.utils.file_utils import get_project_base_directory
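

# "huqie" is a hybrid Chinese/English tokenizer backed by a datrie dictionary:
# non-Chinese runs go through NLTK (word_tokenize + lemmatizer + stemmer), while
# Chinese runs are segmented by reconciling forward/backward maximum matching with
# a DFS over the trie.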
class RagTokenizer:
    def key_(self, line):
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
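
    # key_()/rkey_() turn a term into its forward / reversed ("DD"-prefixed) UTF-8 trie
    # key. loadDict_() reads a whitespace-separated dictionary file of "term freq tag"
    # lines, stores log-scaled frequencies and tags in the trie, and caches the built
    # trie next to the dictionary file as "<fnm>.trie".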
    def loadDict_(self, fnm):
        logging.info(f"[HUQIE]:Build trie from {fnm}")
        try:
            of = open(fnm, "r", encoding='utf-8')
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1

            dict_file_cache = fnm + ".trie"
            logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
            self.trie_.save(dict_file_cache)
            of.close()
        except Exception:
            logging.exception(f"[HUQIE]:Build trie {fnm} failed")
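
    # Frequencies are stored as int(log(freq / DENOMINATOR) + .5), so freq() below
    # recovers an approximate raw count with exp(); the default dictionary lives at
    # rag/res/huqie.txt under the project base directory.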
    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"

        trie_file_name = self.DIR_ + ".txt.trie"
        # check whether the trie cache file exists
        if os.path.exists(trie_file_name):
            try:
                # load the trie from the cache file
                self.trie_ = datrie.Trie.load(trie_file_name)
                return
            except Exception:
                # failed to load the cached trie; fall back to building the default trie
                logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
                self.trie_ = datrie.Trie(string.printable)
        else:
            # cache file does not exist, build the default trie
            logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
            self.trie_ = datrie.Trie(string.printable)

        # load entries from the dict file and save them to the trie cache file
        self.loadDict_(self.DIR_ + ".txt")
    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)
    def _strQ2B(self, ustring):
        """Convert full-width characters to half-width characters"""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # After the conversion, if it's not a half-width character, keep the original character.
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        return HanziConv.toSimplified(line)
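
    # dfs_() enumerates candidate segmentations of `chars` starting at position `s`:
    # it extends the current prefix with every dictionary word found in the trie,
    # recurses on the remainder, and appends each complete segmentation to `tkslist`.
    # Single characters missing from the trie are kept with a penalty score of -12.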
    def dfs_(self, chars, s, preTks, tkslist):
        res = s
        # if s > MAX_L or s >= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res

        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)

            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break

            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                pretks.append((t, self.trie_[k]))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))

        return self.dfs_(chars, s + 1, preTks, tkslist)
    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]
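
    # score_() rates one candidate segmentation: fewer tokens (B / len(tks)), a larger
    # share of multi-character tokens (L) and a larger sum of log-scaled frequencies (F)
    # all raise the score; sortTks_() uses it to rank candidates best-first.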
    def score_(self, tfts):
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        # F /= len(tks)
        L /= len(tks)
        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F
    def sortTks_(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)
    def merge_(self, tks):
        # re-join adjacent tokens when a split character turns out to be part of a dictionary word
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split()
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)
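
    # maxForward_()/maxBackward_() perform greedy maximum matching against the trie,
    # scanning left-to-right and right-to-left respectively, and return the resulting
    # token list together with its score_() value.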
    def maxForward_(self, line):
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])
    def english_normalize_(self, tks):
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
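
    # _split_by_lang() first splits on SPLIT_CHAR, then cuts each fragment into runs of
    # consecutive characters that are either all Chinese or all non-Chinese, returning
    # (text, is_chinese) pairs.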
    def _split_by_lang(self, line):
        txt_lang_pairs = []
        arr = re.split(self.SPLIT_CHAR, line)
        for a in arr:
            if not a:
                continue
            s = 0
            e = s + 1
            zh = is_chinese(a[s])
            while e < len(a):
                _zh = is_chinese(a[e])
                if _zh == zh:
                    e += 1
                    continue
                txt_lang_pairs.append((a[s: e], zh))
                s = e
                e = s + 1
                zh = _zh
            if s >= len(a):
                continue
            txt_lang_pairs.append((a[s: e], zh))
        return txt_lang_pairs
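
    # tokenize() is the main entry point: normalize the line (full- to half-width,
    # lowercase, traditional to simplified Chinese), split it by language, send
    # non-Chinese runs through NLTK tokenization plus lemmatization/stemming, and
    # segment Chinese runs by reconciling forward and backward maximum matching;
    # spans where the two disagree are re-segmented with dfs_() and the best-scoring
    # candidate is kept, then merge_() re-joins fragments that form known words.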
    def tokenize(self, line):
        line = re.sub(r"\W+", " ", line)
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        arr = self._split_by_lang(line)
        res = []
        for L, lang in arr:
            if not lang:
                res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
                continue
            if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue

            # run forward maximum matching first, then backward maximum matching
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                logging.debug("[FW] {} {}".format(tks, s))
                logging.debug("[BW] {} {}".format(tks1, s1))

            # align the two segmentations: identical prefixes are accepted directly
            i, j, _i, _j = 0, 0, 0, 0
            same = 0
            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                same += 1
            if same > 0:
                res.append(" ".join(tks[j: j + same]))
            _i = i + same
            _j = j + same
            j = _j + 1
            i = _i + 1

            while i < len(tks1) and j < len(tks):
                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
                if tk1 != tk:
                    if len(tk1) > len(tk):
                        j += 1
                    else:
                        i += 1
                    continue

                if tks1[i] != tks[j]:
                    i += 1
                    j += 1
                    continue
                # backward tokens from _i to i differ from forward tokens from _j to j:
                # re-segment that span with dfs_ and keep the best-scoring candidate
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                    same += 1
                res.append(" ".join(tks[j: j + same]))

                _i = i + same
                _j = j + same
                j = _j + 1
                i = _i + 1

            if _i < len(tks1):
                assert _j < len(tks)
                assert "".join(tks1[_i:]) == "".join(tks[_j:])
                tkslist = []
                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

        res = " ".join(res)
        logging.debug("[TKS] {}".format(self.merge_(res)))
        return self.merge_(res)
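
    # fine_grained_tokenize() takes the space-joined output of tokenize() and re-splits
    # medium-length Chinese tokens into smaller dictionary words via dfs_()/sortTks_();
    # for mostly non-Chinese input it only splits tokens on "/".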
    def fine_grained_tokenize(self, tks):
        tks = tks.split()
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)
            res.append(stk)

        return " ".join(self.english_normalize_(res))


def is_chinese(s):
    if s >= u'\u4e00' and s <= u'\u9fa5':
        return True
    else:
        return False


def is_number(s):
    if s >= u'\u0030' and s <= u'\u0039':
        return True
    else:
        return False


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (s >= u'\u0061' and s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
    tks = []
    for t in txt.split():
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
            tks.append(" ")
        tks.append(t)
    return tks
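

# A module-level singleton is created at import time so callers can use the function
# aliases below (tokenize, fine_grained_tokenize, tag, freq, ...) without instantiating
# RagTokenizer themselves.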
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
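
# Minimal usage sketch (the import path is assumed; the bundled rag/res/huqie.txt
# dictionary must be reachable from the project base directory for meaningful results):
#
#   from rag.nlp import rag_tokenizer
#
#   coarse = rag_tokenizer.tokenize("多校划片就是一个小区对应多个小学初中")
#   fine = rag_tokenizer.fine_grained_tokenize(coarse)
#   print(coarse, fine)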


if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    logging.info(tknzr.fine_grained_tokenize(tks))
    # both a user dictionary path and an input text file are required below
    if len(sys.argv) < 3:
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        logging.info(tknzr.tokenize(line))
    of.close()