
term_weight.py 7.6KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import json
import re
import os

import numpy as np

from rag.nlp import rag_tokenizer
from api.utils.file_utils import get_project_base_directory


class Dealer:
    def __init__(self):
        # Query-side stop words (mostly Chinese function words and question
        # particles) that carry no retrieval signal.
        self.stop_words = set([
            "请问", "您", "你", "我", "他", "是", "的", "就", "有", "于",
            "及", "即", "在", "为", "最", "有", "从", "以", "了", "将",
            "与", "吗", "吧", "中", "#", "什么", "怎么", "哪个", "哪些",
            "啥", "相关"
        ])

        def load_dict(fnm):
            # Read a tab-separated "term<TAB>count" file. If every count is
            # zero or missing, fall back to returning a plain set of terms.
            res = {}
            with open(fnm, "r") as f:
                for line in f:
                    arr = line.rstrip("\n").split("\t")
                    if len(arr) < 2:
                        res[arr[0]] = 0
                    else:
                        res[arr[0]] = int(arr[1])
            if sum(res.values()) == 0:
                return set(res.keys())
            return res

        fnm = os.path.join(get_project_base_directory(), "rag/res")
        self.ne, self.df = {}, {}
        try:
            self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
        except Exception:
            print("[WARNING] Load ner.json FAIL!")
        try:
            self.df = load_dict(os.path.join(fnm, "term.freq"))
        except Exception:
            print("[WARNING] Load term.freq FAIL!")
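        # Expected resource formats (inferred from the loaders above):
        #   ner.json  - JSON object mapping a term to a tag such as "corp",
        #               "loca", "sch", "stock", "toxic", "func" or "firstnm".
        #   term.freq - one "term<TAB>count" pair per line; if every count is
        #               zero it degrades to a plain term set.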

    def pretoken(self, txt, num=False, stpwd=True):
        # Punctuation patterns (ASCII and full-width) that should not survive
        # as query terms.
        patt = [
            r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
        ]
        # (pattern, replacement) rewrite rules; currently none.
        rewt = []
        for p, r in rewt:
            txt = re.sub(p, r, txt)

        res = []
        for t in rag_tokenizer.tokenize(txt).split(" "):
            tk = t
            # Drop stop words and, unless num=True, bare single digits.
            if (stpwd and tk in self.stop_words) or (
                    re.match(r"[0-9]$", tk) and not num):
                continue
            # Mask pure punctuation tokens so they can be filtered out below.
            for p in patt:
                if re.match(p, t):
                    tk = "#"
                    break
            # Escape characters that are special in later regex/query handling.
            tk = re.sub(r"([\+\\-])", r"\\\1", tk)
            if tk != "#" and tk:
                res.append(tk)
        return res
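    # For example (the exact split depends on rag_tokenizer's dictionaries),
    # pretoken() on a query like "请问价格是多少?" drops stop words such as
    # "请问" and "是", masks the trailing "?" as "#" and filters it out, and
    # returns the remaining content tokens as a flat list.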

    def tokenMerge(self, tks):
        def oneTerm(t):
            return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):
            j = i
            # A leading single-character token followed by a longer CJK token
            # is kept together as one phrase, e.g. "多 工位" ("multi" + "workstation").
            if i == 0 and oneTerm(tks[i]) and len(tks) > 1 and (
                    len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):
                res.append(" ".join(tks[0:2]))
                i = 2
                continue
            # Otherwise, collect a run of consecutive single-character terms.
            while j < len(tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
                j += 1
            if j - i > 1:
                if j - i < 5:
                    res.append(" ".join(tks[i:j]))
                    i = j
                else:
                    res.append(" ".join(tks[i:i + 2]))
                    i = i + 2
            else:
                if len(tks[i]) > 0:
                    res.append(tks[i])
                i += 1
        return [t for t in res if t]
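    # For example, a run of single-character tokens such as ["多", "工", "位"]
    # comes back as the single space-joined phrase "多 工 位", while runs of
    # five or more single-character tokens are merged two at a time.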

    def ner(self, t):
        # Look up the NER tag for a term; empty string if unknown.
        if not self.ne:
            return ""
        return self.ne.get(t, "")

    def split(self, txt):
        tks = []
        for t in re.sub(r"[ \t]+", " ", txt).split(" "):
            # Glue consecutive alphabetic tokens back into one phrase, unless
            # either of them is tagged as a function word ("func").
            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
                    re.match(r".*[a-zA-Z]$", t) and \
                    self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
                tks[-1] = tks[-1] + " " + t
            else:
                tks.append(t)
        return tks
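    # For example, split("new york  price") glues the adjacent alphabetic
    # tokens into ["new york price"], provided none of them is tagged "func"
    # in the NER dictionary.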

    def weights(self, tks):
        def skill(t):
            # Boost terms listed in a skill dictionary (self.sk); this helper
            # is not used below and self.sk is not loaded in this file.
            if t not in self.sk:
                return 1
            return 6

        def ner(t):
            # Weight by named-entity type: numbers and organisations matter
            # more, short alphabetic fragments are nearly ignored.
            if re.match(r"[0-9,.]{2,}$", t):
                return 2
            if re.match(r"[a-z]{1,2}$", t):
                return 0.01
            if not self.ne or t not in self.ne:
                return 1
            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3,
                 "stock": 3, "firstnm": 1}
            return m[self.ne[t]]

        def postag(t):
            # Weight by part-of-speech tag: pronouns/conjunctions/adverbs are
            # down-weighted, place and organisation names and nouns boosted.
            t = rag_tokenizer.tag(t)
            if t in set(["r", "c", "d"]):
                return 0.3
            if t in set(["ns", "nt"]):
                return 3
            if t in set(["n"]):
                return 2
            if re.match(r"[0-9-]+", t):
                return 2
            return 1

        def freq(t):
            # Corpus term frequency with fallbacks: pure numbers and pure
            # Latin tokens get fixed values, and long unknown terms are split
            # into fine-grained tokens and scored recursively.
            if re.match(r"[0-9. -]{2,}$", t):
                return 3
            s = rag_tokenizer.freq(t)
            if not s and re.match(r"[a-z. -]+$", t):
                return 300
            if not s:
                s = 0
            if not s and len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split(" ") if len(tt) > 1]
                if len(s) > 1:
                    s = np.min([freq(tt) for tt in s]) / 6.
                else:
                    s = 0
            return max(s, 10)

        def df(t):
            # Document frequency from the term.freq table, with the same kind
            # of fallbacks as freq().
            if re.match(r"[0-9. -]{2,}$", t):
                return 5
            if t in self.df:
                return self.df[t] + 3
            elif re.match(r"[a-z. -]+$", t):
                return 300
            elif len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split(" ") if len(tt) > 1]
                if len(s) > 1:
                    return max(3, np.min([df(tt) for tt in s]) / 6.)
            return 3

        def idf(s, N):
            return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
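        # The final weight per term mixes two smoothed IDF signals, assuming
        # (as the constants below suggest) roughly 1e7 tokens behind freq()
        # and 1e9 documents behind df():
        #
        #     weight(t) = (0.3 * idf(freq(t), 1e7) + 0.7 * idf(df(t), 1e9))
        #                 * ner(t) * postag(t)
        #
        # The weights are then normalised so they sum to 1 over the query.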
        tw = []
        for tk in tks:
            tt = self.tokenMerge(self.pretoken(tk, True))
            idf1 = np.array([idf(freq(t), 10000000) for t in tt])
            idf2 = np.array([idf(df(t), 1000000000) for t in tt])
            wts = (0.3 * idf1 + 0.7 * idf2) * \
                np.array([ner(t) * postag(t) for t in tt])
            tw.extend(zip(tt, wts))

        S = np.sum([s for _, s in tw])
        return [(t, s / S) for t, s in tw]
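

# Minimal usage sketch (assumes the rag/res resources and rag_tokenizer are
# available on the import path; the query string is only an example):
if __name__ == "__main__":
    dealer = Dealer()
    # weights() takes a list of raw query strings and returns (term, weight)
    # pairs whose weights are normalised to sum to 1.
    for term, weight in dealer.weights(["多工位自动化装配的价格"]):
        print(term, round(float(weight), 4))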