您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

rag_tokenizer.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. import datrie
  18. import math
  19. import os
  20. import re
  21. import string
  22. import sys
  23. from hanziconv import HanziConv
  24. from huggingface_hub import snapshot_download
  25. from nltk import word_tokenize
  26. from nltk.stem import PorterStemmer, WordNetLemmatizer
  27. from api.utils.file_utils import get_project_base_directory
  28. class RagTokenizer:
  29. def key_(self, line):
  30. return str(line.lower().encode("utf-8"))[2:-1]
  31. def rkey_(self, line):
  32. return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
  33. def loadDict_(self, fnm):
  34. print("[HUQIE]:Build trie", fnm, file=sys.stderr)
  35. try:
  36. of = open(fnm, "r", encoding='utf-8')
  37. while True:
  38. line = of.readline()
  39. if not line:
  40. break
  41. line = re.sub(r"[\r\n]+", "", line)
  42. line = re.split(r"[ \t]", line)
  43. k = self.key_(line[0])
  44. F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
  45. if k not in self.trie_ or self.trie_[k][0] < F:
  46. self.trie_[self.key_(line[0])] = (F, line[2])
  47. self.trie_[self.rkey_(line[0])] = 1
  48. self.trie_.save(fnm + ".trie")
  49. of.close()
  50. except Exception as e:
  51. print("[HUQIE]:Faild to build trie, ", fnm, e, file=sys.stderr)
  52. def __init__(self, debug=False):
  53. self.DEBUG = debug
  54. self.DENOMINATOR = 1000000
  55. self.trie_ = datrie.Trie(string.printable)
  56. self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
  57. self.stemmer = PorterStemmer()
  58. self.lemmatizer = WordNetLemmatizer()
  59. self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
  60. try:
  61. self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
  62. return
  63. except Exception as e:
  64. print("[HUQIE]:Build default trie", file=sys.stderr)
  65. self.trie_ = datrie.Trie(string.printable)
  66. self.loadDict_(self.DIR_ + ".txt")
  67. def loadUserDict(self, fnm):
  68. try:
  69. self.trie_ = datrie.Trie.load(fnm + ".trie")
  70. return
  71. except Exception as e:
  72. self.trie_ = datrie.Trie(string.printable)
  73. self.loadDict_(fnm)
  74. def addUserDict(self, fnm):
  75. self.loadDict_(fnm)
  76. def _strQ2B(self, ustring):
  77. """把字符串全角转半角"""
  78. rstring = ""
  79. for uchar in ustring:
  80. inside_code = ord(uchar)
  81. if inside_code == 0x3000:
  82. inside_code = 0x0020
  83. else:
  84. inside_code -= 0xfee0
  85. if inside_code < 0x0020 or inside_code > 0x7e: # 转完之后不是半角字符返回原来的字符
  86. rstring += uchar
  87. else:
  88. rstring += chr(inside_code)
  89. return rstring
  90. def _tradi2simp(self, line):
  91. return HanziConv.toSimplified(line)
  92. def dfs_(self, chars, s, preTks, tkslist):
  93. MAX_L = 10
  94. res = s
  95. # if s > MAX_L or s>= len(chars):
  96. if s >= len(chars):
  97. tkslist.append(preTks)
  98. return res
  99. # pruning
  100. S = s + 1
  101. if s + 2 <= len(chars):
  102. t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
  103. if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
  104. self.key_(t2)):
  105. S = s + 2
  106. if len(preTks) > 2 and len(
  107. preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
  108. t1 = preTks[-1][0] + "".join(chars[s:s + 1])
  109. if self.trie_.has_keys_with_prefix(self.key_(t1)):
  110. S = s + 2
  111. ################
  112. for e in range(S, len(chars) + 1):
  113. t = "".join(chars[s:e])
  114. k = self.key_(t)
  115. if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
  116. break
  117. if k in self.trie_:
  118. pretks = copy.deepcopy(preTks)
  119. if k in self.trie_:
  120. pretks.append((t, self.trie_[k]))
  121. else:
  122. pretks.append((t, (-12, '')))
  123. res = max(res, self.dfs_(chars, e, pretks, tkslist))
  124. if res > s:
  125. return res
  126. t = "".join(chars[s:s + 1])
  127. k = self.key_(t)
  128. if k in self.trie_:
  129. preTks.append((t, self.trie_[k]))
  130. else:
  131. preTks.append((t, (-12, '')))
  132. return self.dfs_(chars, s + 1, preTks, tkslist)
  133. def freq(self, tk):
  134. k = self.key_(tk)
  135. if k not in self.trie_:
  136. return 0
  137. return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
  138. def tag(self, tk):
  139. k = self.key_(tk)
  140. if k not in self.trie_:
  141. return ""
  142. return self.trie_[k][1]
  143. def score_(self, tfts):
  144. B = 30
  145. F, L, tks = 0, 0, []
  146. for tk, (freq, tag) in tfts:
  147. F += freq
  148. L += 0 if len(tk) < 2 else 1
  149. tks.append(tk)
  150. F /= len(tks)
  151. L /= len(tks)
  152. if self.DEBUG:
  153. print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
  154. return tks, B / len(tks) + L + F
  155. def sortTks_(self, tkslist):
  156. res = []
  157. for tfts in tkslist:
  158. tks, s = self.score_(tfts)
  159. res.append((tks, s))
  160. return sorted(res, key=lambda x: x[1], reverse=True)
  161. def merge_(self, tks):
  162. patts = [
  163. (r"[ ]+", " "),
  164. (r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
  165. ]
  166. # for p,s in patts: tks = re.sub(p, s, tks)
  167. # if split chars is part of token
  168. res = []
  169. tks = re.sub(r"[ ]+", " ", tks).split(" ")
  170. s = 0
  171. while True:
  172. if s >= len(tks):
  173. break
  174. E = s + 1
  175. for e in range(s + 2, min(len(tks) + 2, s + 6)):
  176. tk = "".join(tks[s:e])
  177. if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
  178. E = e
  179. res.append("".join(tks[s:E]))
  180. s = E
  181. return " ".join(res)
  182. def maxForward_(self, line):
  183. res = []
  184. s = 0
  185. while s < len(line):
  186. e = s + 1
  187. t = line[s:e]
  188. while e < len(line) and self.trie_.has_keys_with_prefix(
  189. self.key_(t)):
  190. e += 1
  191. t = line[s:e]
  192. while e - 1 > s and self.key_(t) not in self.trie_:
  193. e -= 1
  194. t = line[s:e]
  195. if self.key_(t) in self.trie_:
  196. res.append((t, self.trie_[self.key_(t)]))
  197. else:
  198. res.append((t, (0, '')))
  199. s = e
  200. return self.score_(res)
  201. def maxBackward_(self, line):
  202. res = []
  203. s = len(line) - 1
  204. while s >= 0:
  205. e = s + 1
  206. t = line[s:e]
  207. while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
  208. s -= 1
  209. t = line[s:e]
  210. while s + 1 < e and self.key_(t) not in self.trie_:
  211. s += 1
  212. t = line[s:e]
  213. if self.key_(t) in self.trie_:
  214. res.append((t, self.trie_[self.key_(t)]))
  215. else:
  216. res.append((t, (0, '')))
  217. s -= 1
  218. return self.score_(res[::-1])
  219. def english_normalize_(self, tks):
  220. return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
  221. def tokenize(self, line):
  222. line = self._strQ2B(line).lower()
  223. line = self._tradi2simp(line)
  224. zh_num = len([1 for c in line if is_chinese(c)])
  225. if zh_num == 0:
  226. return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
  227. arr = re.split(self.SPLIT_CHAR, line)
  228. res = []
  229. for L in arr:
  230. if len(L) < 2 or re.match(
  231. r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
  232. res.append(L)
  233. continue
  234. # print(L)
  235. # use maxforward for the first time
  236. tks, s = self.maxForward_(L)
  237. tks1, s1 = self.maxBackward_(L)
  238. if self.DEBUG:
  239. print("[FW]", tks, s)
  240. print("[BW]", tks1, s1)
  241. i, j, _i, _j = 0, 0, 0, 0
  242. same = 0
  243. while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
  244. same += 1
  245. if same > 0: res.append(" ".join(tks[j: j + same]))
  246. _i = i + same
  247. _j = j + same
  248. j = _j + 1
  249. i = _i + 1
  250. while i < len(tks1) and j < len(tks):
  251. tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
  252. if tk1 != tk:
  253. if len(tk1) > len(tk):
  254. j += 1
  255. else:
  256. i += 1
  257. continue
  258. if tks1[i] != tks[j]:
  259. i += 1
  260. j += 1
  261. continue
  262. # backward tokens from_i to i are different from forward tokens from _j to j.
  263. tkslist = []
  264. self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
  265. res.append(" ".join(self.sortTks_(tkslist)[0][0]))
  266. same = 1
  267. while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
  268. same += 1
  269. res.append(" ".join(tks[j: j + same]))
  270. _i = i + same
  271. _j = j + same
  272. j = _j + 1
  273. i = _i + 1
  274. if _i < len(tks1):
  275. assert _j < len(tks)
  276. assert "".join(tks1[_i:]) == "".join(tks[_j:])
  277. tkslist = []
  278. self.dfs_("".join(tks[_j:]), 0, [], tkslist)
  279. res.append(" ".join(self.sortTks_(tkslist)[0][0]))
  280. res = " ".join(self.english_normalize_(res))
  281. if self.DEBUG:
  282. print("[TKS]", self.merge_(res))
  283. return self.merge_(res)
  284. def fine_grained_tokenize(self, tks):
  285. tks = tks.split(" ")
  286. zh_num = len([1 for c in tks if c and is_chinese(c[0])])
  287. if zh_num < len(tks) * 0.2:
  288. res = []
  289. for tk in tks:
  290. res.extend(tk.split("/"))
  291. return " ".join(res)
  292. res = []
  293. for tk in tks:
  294. if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
  295. res.append(tk)
  296. continue
  297. tkslist = []
  298. if len(tk) > 10:
  299. tkslist.append(tk)
  300. else:
  301. self.dfs_(tk, 0, [], tkslist)
  302. if len(tkslist) < 2:
  303. res.append(tk)
  304. continue
  305. stk = self.sortTks_(tkslist)[1][0]
  306. if len(stk) == len(tk):
  307. stk = tk
  308. else:
  309. if re.match(r"[a-z\.-]+$", tk):
  310. for t in stk:
  311. if len(t) < 3:
  312. stk = tk
  313. break
  314. else:
  315. stk = " ".join(stk)
  316. else:
  317. stk = " ".join(stk)
  318. res.append(stk)
  319. return " ".join(self.english_normalize_(res))
  320. def is_chinese(s):
  321. if s >= u'\u4e00' and s <= u'\u9fa5':
  322. return True
  323. else:
  324. return False
  325. def is_number(s):
  326. if s >= u'\u0030' and s <= u'\u0039':
  327. return True
  328. else:
  329. return False
  330. def is_alphabet(s):
  331. if (s >= u'\u0041' and s <= u'\u005a') or (
  332. s >= u'\u0061' and s <= u'\u007a'):
  333. return True
  334. else:
  335. return False
  336. def naiveQie(txt):
  337. tks = []
  338. for t in txt.split(" "):
  339. if tks and re.match(r".*[a-zA-Z]$", tks[-1]
  340. ) and re.match(r".*[a-zA-Z]$", t):
  341. tks.append(" ")
  342. tks.append(t)
  343. return tks
  344. tokenizer = RagTokenizer()
  345. tokenize = tokenizer.tokenize
  346. fine_grained_tokenize = tokenizer.fine_grained_tokenize
  347. tag = tokenizer.tag
  348. freq = tokenizer.freq
  349. loadUserDict = tokenizer.loadUserDict
  350. addUserDict = tokenizer.addUserDict
  351. tradi2simp = tokenizer._tradi2simp
  352. strQ2B = tokenizer._strQ2B
  353. if __name__ == '__main__':
  354. tknzr = RagTokenizer(debug=True)
  355. # huqie.addUserDict("/tmp/tmp.new.tks.dict")
  356. tks = tknzr.tokenize(
  357. "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
  358. print(tknzr.fine_grained_tokenize(tks))
  359. tks = tknzr.tokenize(
  360. "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
  361. print(tknzr.fine_grained_tokenize(tks))
  362. tks = tknzr.tokenize(
  363. "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
  364. print(tknzr.fine_grained_tokenize(tks))
  365. tks = tknzr.tokenize(
  366. "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
  367. print(tknzr.fine_grained_tokenize(tks))
  368. tks = tknzr.tokenize("虽然我不怎么玩")
  369. print(tknzr.fine_grained_tokenize(tks))
  370. tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
  371. print(tknzr.fine_grained_tokenize(tks))
  372. tks = tknzr.tokenize(
  373. "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
  374. print(tknzr.fine_grained_tokenize(tks))
  375. tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
  376. print(tknzr.fine_grained_tokenize(tks))
  377. tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
  378. print(tknzr.fine_grained_tokenize(tks))
  379. tks = tknzr.tokenize(
  380. "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
  381. print(tknzr.fine_grained_tokenize(tks))
  382. if len(sys.argv) < 2:
  383. sys.exit()
  384. tknzr.DEBUG = False
  385. tknzr.loadUserDict(sys.argv[1])
  386. of = open(sys.argv[2], "r")
  387. while True:
  388. line = of.readline()
  389. if not line:
  390. break
  391. print(tknzr.tokenize(line))
  392. of.close()