
rag_tokenizer.py

# -*- coding: utf-8 -*-
import copy
import datrie
import math
import os
import re
import string
import sys

from hanziconv import HanziConv
from huggingface_hub import snapshot_download
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer

from api.utils.file_utils import get_project_base_directory

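# Dictionary-driven tokenizer ("huqie") for mixed Chinese/English text. A datrie
# trie maps every dictionary term (and its reverse, prefixed with "DD") to a
# (log-frequency, POS-tag) pair. Chinese text is segmented by combining forward
# and backward maximum matching with a DFS over alternative splits; mostly
# English text falls back to NLTK tokenization with lemmatization and stemming.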
class RagTokenizer:
    def key_(self, line):
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
        print("[HUQIE]:Build trie", fnm, file=sys.stderr)
        try:
            of = open(fnm, "r", encoding='utf-8')
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1
            self.trie_.save(fnm + ".trie")
            of.close()
        except Exception as e:
            print("[HUQIE]:Failed to build trie, ", fnm, e, file=sys.stderr)

    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.trie_ = datrie.Trie(string.printable)
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
        try:
            self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
            return
        except Exception:
            print("[HUQIE]:Build default trie", file=sys.stderr)
            self.trie_ = datrie.Trie(string.printable)

        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters in a string to half-width."""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # keep the original char if the result is not half-width
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        return HanziConv.toSimplified(line)

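    # dfs_ enumerates candidate segmentations of chars[s:] recursively, using
    # trie prefix lookups to prune branches that cannot match any dictionary
    # term. Every complete segmentation, a list of (token, (log-frequency, tag))
    # pairs, is appended to tkslist; the return value tracks the furthest
    # position the search managed to cover.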
    def dfs_(self, chars, s, preTks, tkslist):
        MAX_L = 10
        res = s
        # if s > MAX_L or s >= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res

        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)
            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break
            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                pretks.append((t, self.trie_[k]))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))
        return self.dfs_(chars, s + 1, preTks, tkslist)

    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]

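    # score_ rates one candidate segmentation: the average stored log-frequency
    # plus the share of multi-character tokens plus B / len(tks), so fewer,
    # longer and more frequent tokens score higher. sortTks_ ranks candidate
    # segmentations by this score, best first.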
    def score_(self, tfts):
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        F /= len(tks)
        L /= len(tks)
        if self.DEBUG:
            print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
        patts = [
            (r"[ ]+", " "),
            (r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
        ]
        # for p, s in patts: tks = re.sub(p, s, tks)

        # if split chars is part of token
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split(" ")
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E
        return " ".join(res)

    def maxForward_(self, line):
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

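    # tokenize normalizes the input (full-width -> half-width, traditional ->
    # simplified, lower case). Text that is less than 20% Chinese goes through
    # NLTK word_tokenize with lemmatization and stemming; otherwise every span
    # is segmented by forward and backward maximum matching, disagreements are
    # re-segmented via dfs_/sortTks_, and merge_ re-joins fragments that form a
    # known dictionary entry.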
    def tokenize(self, line):
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        zh_num = len([1 for c in line if is_chinese(c)])
        if zh_num < len(line) * 0.2:
            return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])

        arr = re.split(self.SPLIT_CHAR, line)
        res = []
        for L in arr:
            if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue

            # use maxForward_ for the first pass, then compare with maxBackward_
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                print("[FW]", tks, s)
                print("[BW]", tks1, s1)

            # mark the positions where forward and backward matching disagree
            diff = [0 for _ in range(max(len(tks1), len(tks)))]
            for i in range(min(len(tks1), len(tks))):
                if tks[i] != tks1[i]:
                    diff[i] = 1

            if s1 > s:
                tks = tks1

            i = 0
            while i < len(tks):
                s = i
                while s < len(tks) and diff[s] == 0:
                    s += 1
                if s == len(tks):
                    res.append(" ".join(tks[i:]))
                    break
                if s > i:
                    res.append(" ".join(tks[i:s]))

                e = s
                while e < len(tks) and e - s < 5 and diff[e] == 1:
                    e += 1

                tkslist = []
                self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))
                i = e + 1

        res = " ".join(res)
        if self.DEBUG:
            print("[TKS]", self.merge_(res))
        return self.merge_(res)

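    # fine_grained_tokenize splits coarse tokens further: mostly non-Chinese
    # input is only split on "/", while Chinese tokens of moderate length are
    # re-segmented with dfs_ and the second-ranked segmentation is kept,
    # falling back to the original token when the finer split is unconvincing.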
    def fine_grained_tokenize(self, tks):
        tks = tks.split(" ")
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)
            res.append(stk)

        return " ".join(res)

def is_chinese(s):
    if s >= u'\u4e00' and s <= u'\u9fa5':
        return True
    else:
        return False


def is_number(s):
    if s >= u'\u0030' and s <= u'\u0039':
        return True
    else:
        return False


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (s >= u'\u0061' and s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
    tks = []
    for t in txt.split(" "):
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
            tks.append(" ")
        tks.append(t)
    return tks

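# A shared module-level instance; the aliases below let callers use the
# tokenizer functions directly without constructing their own RagTokenizer.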
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B

if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    print(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    print(tknzr.fine_grained_tokenize(tks))

    # optional: pass a user dictionary path and an input file to tokenize line by line
    if len(sys.argv) < 3:
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        print(tknzr.tokenize(line))
    of.close()