
term_weight.py 6.9KB

# -*- coding: utf-8 -*-
import math
import json
import re
import os
import numpy as np
from rag.nlp import rag_tokenizer
from api.utils.file_utils import get_project_base_directory


class Dealer:
    def __init__(self):
        # Common Chinese function and question words (plus "#") that carry
        # no retrieval signal: pronouns, particles, interrogatives, etc.
        self.stop_words = set(["请问",
                               "您",
                               "你",
                               "我",
                               "他",
                               "是",
                               "的",
                               "就",
                               "有",
                               "于",
                               "及",
                               "即",
                               "在",
                               "为",
                               "最",
                               "从",
                               "以",
                               "了",
                               "将",
                               "与",
                               "吗",
                               "吧",
                               "中",
                               "#",
                               "什么",
                               "怎么",
                               "哪个",
                               "哪些",
                               "啥",
                               "相关"])
        def load_dict(fnm):
            # Parse one tab-separated "term<TAB>count" entry per line; a
            # missing count defaults to 0.  If every count is 0 the counts
            # carry no information, so return just the set of terms.
            res = {}
            with open(fnm, "r") as f:
                for line in f:
                    arr = line.replace("\n", "").split("\t")
                    if len(arr) < 2:
                        res[arr[0]] = 0
                    else:
                        res[arr[0]] = int(arr[1])
            if sum(res.values()) == 0:
                return set(res.keys())
            return res
        fnm = os.path.join(get_project_base_directory(), "rag/res")
        self.ne, self.df = {}, {}
        try:
            self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
        except Exception:
            print("[WARNING] Load ner.json FAIL!")
        try:
            self.df = load_dict(os.path.join(fnm, "term.freq"))
        except Exception:
            print("[WARNING] Load term.freq FAIL!")
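    # Illustrative resource shapes (the sample values are assumptions; only
    # the tag names appear elsewhere in this file):
    #   ner.json  maps a term to one of the NER tags consumed by weights(),
    #             e.g. {"apple": "corp", "beijing": "loca", "tom": "firstnm"}
    #   term.freq holds one "term<TAB>count" line per term; counts may be
    #             omitted, in which case load_dict() degrades to a plain set.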
    def pretoken(self, txt, num=False, stpwd=True):
        # Punctuation (ASCII and full-width) to be masked out.
        patt = [
            r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
        ]
        # Rewrite rules applied before tokenization (currently none).
        rewt = [
        ]
        for p, r in rewt:
            txt = re.sub(p, r, txt)

        res = []
        for t in rag_tokenizer.tokenize(txt).split(" "):
            tk = t
            # Drop stop words and, unless num=True, bare single digits.
            if (stpwd and tk in self.stop_words) or (
                    re.match(r"[0-9]$", tk) and not num):
                continue
            for p in patt:
                if re.match(p, t):
                    tk = "#"
                    break
            # Escape "+", "-" and "\" inside the token.
            tk = re.sub(r"([\+\\-])", r"\\\1", tk)
            if tk != "#" and tk:
                res.append(tk)
        return res
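    # Hedged sketch of pretoken() (the exact split depends on rag_tokenizer,
    # so the tokens below are assumptions): for "什么是 rag 模型?" the stop
    # words "什么"/"是" and the full-width "?" are dropped, leaving roughly
    # ["rag", "模型"].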
    def tokenMerge(self, tks):
        def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):
            j = i
            # Keep a leading one-character token together with the longer
            # token that follows it, e.g. "多 工位" ("multi" + "station").
            if i == 0 and oneTerm(tks[i]) and len(
                    tks) > 1 and len(tks[i + 1]) > 1:
                res.append(" ".join(tks[0:2]))
                i = 2
                continue
            # Glue a run of short tokens into one phrase; runs of 5 or more
            # are cut down to the first two tokens.
            while j < len(
                    tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
                j += 1
            if j - i > 1:
                if j - i < 5:
                    res.append(" ".join(tks[i:j]))
                    i = j
                else:
                    res.append(" ".join(tks[i:i + 2]))
                    i = i + 2
            else:
                if len(tks[i]) > 0:
                    res.append(tks[i])
                i += 1
        return [t for t in res if t]
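    # tokenMerge() is a pure function of the token list (given the stop
    # words above), so these traces hold:
    #   tokenMerge(["ct", "x", "y", "scanner"]) -> ["ct x y", "scanner"]
    #   tokenMerge(["a", "big", "token"])       -> ["a big", "token"]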
    def ner(self, t):
        # Named-entity tag for a term; "" when unknown.
        if not self.ne:
            return ""
        return self.ne.get(t, "")
    def split(self, txt):
        tks = []
        for t in re.sub(r"[ \t]+", " ", txt).split(" "):
            # Re-join consecutive English tokens into one phrase unless
            # either side is tagged "func" in ner.json.
            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
                    re.match(r".*[a-zA-Z]$", t) and \
                    self.ne.get(t, "") != "func" and \
                    self.ne.get(tks[-1], "") != "func":
                tks[-1] = tks[-1] + " " + t
            else:
                tks.append(t)
        return tks
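    # For example, with neither word tagged "func" in ner.json:
    #   split("machine learning 模型") -> ["machine learning", "模型"]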
    def weights(self, tks):
        def skill(t):
            # NOTE: unused below; self.sk is not initialised anywhere in
            # this file.
            if t not in self.sk:
                return 1
            return 6

        def ner(t):
            # Weight multiplier derived from the NER tag.
            if re.match(r"[0-9,.]{2,}$", t):
                return 2
            if re.match(r"[a-z]{1,2}$", t):
                return 0.01
            if not self.ne or t not in self.ne:
                return 1
            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3,
                 "stock": 3, "firstnm": 1}
            return m[self.ne[t]]

        def postag(t):
            # Weight multiplier derived from the part-of-speech tag.
            t = rag_tokenizer.tag(t)
            if t in set(["r", "c", "d"]):
                return 0.3
            if t in set(["ns", "nt"]):
                return 3
            if t in set(["n"]):
                return 2
            if re.match(r"[0-9-]+", t):
                return 2
            return 1

        def freq(t):
            # Corpus term frequency, falling back to fine-grained subtokens
            # for long unseen terms; never below 10.
            if re.match(r"[0-9. -]{2,}$", t):
                return 3
            s = rag_tokenizer.freq(t)
            if not s and re.match(r"[a-z. -]+$", t):
                return 300
            if not s:
                s = 0
            if not s and len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split(" ")
                     if len(tt) > 1]
                if len(s) > 1:
                    s = np.min([freq(tt) for tt in s]) / 6.
                else:
                    s = 0
            return max(s, 10)

        def df(t):
            # Document frequency with the same subtoken fallback; never
            # below 3.
            if re.match(r"[0-9. -]{2,}$", t):
                return 5
            if t in self.df:
                return self.df[t] + 3
            elif re.match(r"[a-z. -]+$", t):
                return 300
            elif len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split(" ")
                     if len(tt) > 1]
                if len(s) > 1:
                    return max(3, np.min([df(tt) for tt in s]) / 6.)
            return 3

        # Smoothed, BM25-style inverse document frequency.
        def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

        tw = []
        for tk in tks:
            tt = self.tokenMerge(self.pretoken(tk, True))
            idf1 = np.array([idf(freq(t), 10000000) for t in tt])
            idf2 = np.array([idf(df(t), 1000000000) for t in tt])
            wts = (0.3 * idf1 + 0.7 * idf2) * \
                np.array([ner(t) * postag(t) for t in tt])
            tw.extend(zip(tt, wts))

        # Normalize so that all weights sum to 1.
        S = np.sum([s for _, s in tw])
        return [(t, s / S) for t, s in tw]
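# A minimal usage sketch, not part of the original file: it assumes the
# rag/res resources are present and that rag_tokenizer can handle the query.
# weights() yields (term, weight) pairs whose weights sum to 1.
if __name__ == "__main__":
    dealer = Dealer()
    for term, w in dealer.weights(dealer.split("what is the best rag pipeline")):
        print(term, round(float(w), 4))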