
term_weight.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
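# Term weighting for retrieval: tokenizes a query, merges very short tokens
# into small phrases, and assigns each term a weight from smoothed IDF
# estimates combined with NER-type and part-of-speech boosts
# (see Dealer.weights below).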
import logging
import math
import json
import re
import os
import numpy as np
from rag.nlp import rag_tokenizer
from api.utils.file_utils import get_project_base_directory


class Dealer:
    def __init__(self):
        # Stop words (mostly Chinese function and question words) dropped
        # during pre-tokenization.
        self.stop_words = set(["请问",
                               "您",
                               "你",
                               "我",
                               "他",
                               "是",
                               "的",
                               "就",
                               "有",
                               "于",
                               "及",
                               "即",
                               "在",
                               "为",
                               "最",
                               "有",
                               "从",
                               "以",
                               "了",
                               "将",
                               "与",
                               "吗",
                               "吧",
                               "中",
                               "#",
                               "什么",
                               "怎么",
                               "哪个",
                               "哪些",
                               "啥",
                               "相关"])
        def load_dict(fnm):
            # Load a tab-separated "term<TAB>count" file; if no line carries
            # a count, return just the set of terms.
            res = {}
            with open(fnm, "r") as f:
                for line in f:
                    arr = line.replace("\n", "").split("\t")
                    if len(arr) < 2:
                        res[arr[0]] = 0
                    else:
                        res[arr[0]] = int(arr[1])

            c = 0
            for _, v in res.items():
                c += v
            if c == 0:
                return set(res.keys())
            return res

        fnm = os.path.join(get_project_base_directory(), "rag/res")
        self.ne, self.df = {}, {}  # NER tag per term, term frequencies
        try:
            self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
        except Exception:
            logging.warning("Load ner.json FAIL!")
        try:
            self.df = load_dict(os.path.join(fnm, "term.freq"))
        except Exception:
            logging.warning("Load term.freq FAIL!")
    def pretoken(self, txt, num=False, stpwd=True):
        # Tokenize the text, drop stop words and (unless num=True) bare
        # single-digit tokens, and filter out pure-punctuation tokens.
        patt = [
            r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
        ]
        rewt = [
        ]
        for p, r in rewt:
            txt = re.sub(p, r, txt)

        res = []
        for t in rag_tokenizer.tokenize(txt).split():
            tk = t
            if (stpwd and tk in self.stop_words) or (
                    re.match(r"[0-9]$", tk) and not num):
                continue
            for p in patt:
                if re.match(p, t):
                    tk = "#"
                    break
            # tk = re.sub(r"([\+\\-])", r"\\\1", tk)
            if tk != "#" and tk:
                res.append(tk)
        return res
    def tokenMerge(self, tks):
        # Merge runs of single-character (or very short alphanumeric) tokens
        # into small phrases so they are weighted together.
        def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):
            j = i
            if i == 0 and oneTerm(tks[i]) and len(
                    tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):  # e.g. "多 工位", an over-split term
                res.append(" ".join(tks[0:2]))
                i = 2
                continue

            while j < len(
                    tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
                j += 1
            if j - i > 1:
                if j - i < 5:
                    res.append(" ".join(tks[i:j]))
                    i = j
                else:
                    res.append(" ".join(tks[i:i + 2]))
                    i = i + 2
            else:
                if len(tks[i]) > 0:
                    res.append(tks[i])
                i += 1
        return [t for t in res if t]
    def ner(self, t):
        # Look up the named-entity tag for a term ("" when unknown).
        if not self.ne:
            return ""
        return self.ne.get(t, "")

    def split(self, txt):
        # Split on whitespace, but re-join consecutive tokens that both end
        # in a Latin letter unless either is tagged as a function word.
        tks = []
        for t in re.sub(r"[ \t]+", " ", txt).split():
            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
                    re.match(r".*[a-zA-Z]$", t) and tks and \
                    self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
                tks[-1] = tks[-1] + " " + t
            else:
                tks.append(t)
        return tks
    def weights(self, tks, preprocess=True):
        # Return (term, weight) pairs normalized to sum to 1. Each raw weight
        # blends two smoothed IDF estimates (0.3 * token-frequency IDF +
        # 0.7 * document-frequency IDF), scaled by NER-type and POS-tag boosts.
        def skill(t):
            # NOTE: unused here; self.sk is not initialized in this class.
            if t not in self.sk:
                return 1
            return 6

        def ner(t):
            if re.match(r"[0-9,.]{2,}$", t):
                return 2
            if re.match(r"[a-z]{1,2}$", t):
                return 0.01
            if not self.ne or t not in self.ne:
                return 1
            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
                 "firstnm": 1}
            return m[self.ne[t]]

        def postag(t):
            t = rag_tokenizer.tag(t)
            if t in set(["r", "c", "d"]):
                return 0.3
            if t in set(["ns", "nt"]):
                return 3
            if t in set(["n"]):
                return 2
            if re.match(r"[0-9-]+", t):
                return 2
            return 1

        def freq(t):
            if re.match(r"[0-9. -]{2,}$", t):
                return 3
            s = rag_tokenizer.freq(t)
            if not s and re.match(r"[a-z. -]+$", t):
                return 300
            if not s:
                s = 0

            if not s and len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
                if len(s) > 1:
                    s = np.min([freq(tt) for tt in s]) / 6.
                else:
                    s = 0

            return max(s, 10)

        def df(t):
            if re.match(r"[0-9. -]{2,}$", t):
                return 5
            if t in self.df:
                return self.df[t] + 3
            elif re.match(r"[a-z. -]+$", t):
                return 300
            elif len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
                if len(s) > 1:
                    return max(3, np.min([df(tt) for tt in s]) / 6.)

            return 3

        # Smoothed IDF: log10 of 10 plus the classic (N - s + 0.5) / (s + 0.5) ratio.
        def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

        tw = []
        if not preprocess:
            idf1 = np.array([idf(freq(t), 10000000) for t in tks])
            idf2 = np.array([idf(df(t), 1000000000) for t in tks])
            wts = (0.3 * idf1 + 0.7 * idf2) * \
                np.array([ner(t) * postag(t) for t in tks])

            wts = [s for s in wts]

            tw = list(zip(tks, wts))
        else:
            for tk in tks:
                tt = self.tokenMerge(self.pretoken(tk, True))
                idf1 = np.array([idf(freq(t), 10000000) for t in tt])
                idf2 = np.array([idf(df(t), 1000000000) for t in tt])
                wts = (0.3 * idf1 + 0.7 * idf2) * \
                    np.array([ner(t) * postag(t) for t in tt])

                wts = [s for s in wts]

                tw.extend(zip(tt, wts))

        S = np.sum([s for _, s in tw])
        return [(t, s / S) for t, s in tw]
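

# Hypothetical usage sketch (not part of the original module): assuming the
# rag/res resources and the rag_tokenizer models are available, term weights
# for a query can be inspected from the command line. The sample query string
# below is made up for illustration.
if __name__ == "__main__":
    dealer = Dealer()
    # weights() returns (term, weight) pairs normalized to sum to 1, so each
    # printed value can be read as the term's relative importance.
    for term, weight in dealer.weights(["deep learning 检索引擎"]):
        print("%s\t%.4f" % (term, weight))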