Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

query.py 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # -*- coding: utf-8 -*-
  2. import json
  3. import re
  4. import logging
  5. import copy
  6. import math
  7. from elasticsearch_dsl import Q, Search
  8. from rag.nlp import huqie, term_weight, synonym
  9. class EsQueryer:
  10. def __init__(self, es):
  11. self.tw = term_weight.Dealer()
  12. self.es = es
  13. self.syn = synonym.Dealer(None)
  14. self.flds = ["ask_tks^10", "ask_small_tks"]
  15. @staticmethod
  16. def subSpecialChar(line):
  17. return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|~\^])", r"\\\1", line).strip()
  18. @staticmethod
  19. def isChinese(line):
  20. arr = re.split(r"[ \t]+", line)
  21. if len(arr) <= 3:
  22. return True
  23. e = 0
  24. for t in arr:
  25. if not re.match(r"[a-zA-Z]+$", t):
  26. e += 1
  27. return e * 1. / len(arr) >= 0.7
  28. @staticmethod
  29. def rmWWW(txt):
  30. txt = re.sub(
  31. r"是*(什么样的|哪家|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
  32. "",
  33. txt)
  34. return re.sub(
  35. r"(what|who|how|which|where|why|(is|are|were|was) there) (is|are|were|was|to)*", "", txt, re.IGNORECASE)
  36. def question(self, txt, tbl="qa", min_match="60%"):
  37. txt = re.sub(
  38. r"[ \r\n\t,,。??/`!!&]+",
  39. " ",
  40. huqie.tradi2simp(
  41. huqie.strQ2B(
  42. txt.lower()))).strip()
  43. txt = EsQueryer.rmWWW(txt)
  44. if not self.isChinese(txt):
  45. tks = [t for t in txt.split(" ") if t.strip()]
  46. q = tks
  47. for i in range(1, len(tks)):
  48. q.append("\"%s %s\"^2" % (tks[i - 1], tks[i]))
  49. if not q:
  50. q.append(txt)
  51. return Q("bool",
  52. must=Q("query_string", fields=self.flds,
  53. type="best_fields", query=" OR ".join(q),
  54. boost=1, minimum_should_match=min_match)
  55. ), txt.split(" ")
  56. def needQieqie(tk):
  57. if len(tk) < 4:
  58. return False
  59. if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
  60. return False
  61. return True
  62. qs, keywords = [], []
  63. for tt in self.tw.split(txt): # .split(" "):
  64. if not tt:
  65. continue
  66. twts = self.tw.weights([tt])
  67. syns = self.syn.lookup(tt)
  68. logging.info(json.dumps(twts, ensure_ascii=False))
  69. tms = []
  70. for tk, w in sorted(twts, key=lambda x: x[1] * -1):
  71. sm = huqie.qieqie(tk).split(" ") if needQieqie(tk) else []
  72. sm = [
  73. re.sub(
  74. r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
  75. "",
  76. m) for m in sm]
  77. sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
  78. sm = [m for m in sm if len(m) > 1]
  79. if len(sm) < 2:
  80. sm = []
  81. keywords.append(re.sub(r"[ \\\"']+", "", tk))
  82. tk_syns = self.syn.lookup(tk)
  83. tk = EsQueryer.subSpecialChar(tk)
  84. if tk.find(" ") > 0:
  85. tk = "\"%s\"" % tk
  86. if tk_syns:
  87. tk = f"({tk} %s)" % " ".join(tk_syns)
  88. if sm:
  89. tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
  90. " ".join(sm), " ".join(sm))
  91. tms.append((tk, w))
  92. tms = " ".join([f"({t})^{w}" for t, w in tms])
  93. if len(twts) > 1:
  94. tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
  95. if re.match(r"[0-9a-z ]+$", tt):
  96. tms = f"(\"{tt}\" OR \"%s\")" % huqie.qie(tt)
  97. syns = " OR ".join(
  98. ["\"%s\"^0.7" % EsQueryer.subSpecialChar(huqie.qie(s)) for s in syns])
  99. if syns:
  100. tms = f"({tms})^5 OR ({syns})^0.7"
  101. qs.append(tms)
  102. flds = copy.deepcopy(self.flds)
  103. mst = []
  104. if qs:
  105. mst.append(
  106. Q("query_string", fields=flds, type="best_fields",
  107. query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
  108. )
  109. return Q("bool",
  110. must=mst,
  111. ), keywords
  112. def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
  113. vtweight=0.7):
  114. from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
  115. import numpy as np
  116. sims = CosineSimilarity([avec], bvecs)
  117. def toDict(tks):
  118. d = {}
  119. if isinstance(tks, type("")):
  120. tks = tks.split(" ")
  121. for t, c in self.tw.weights(tks):
  122. if t not in d:
  123. d[t] = 0
  124. d[t] += c
  125. return d
  126. atks = toDict(atks)
  127. btkss = [toDict(tks) for tks in btkss]
  128. tksim = [self.similarity(atks, btks) for btks in btkss]
  129. return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
  130. def similarity(self, qtwt, dtwt):
  131. if isinstance(dtwt, type("")):
  132. dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt))}
  133. if isinstance(qtwt, type("")):
  134. qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt))}
  135. s = 1e-9
  136. for k, v in qtwt.items():
  137. if k in dtwt:
  138. s += v# * dtwt[k]
  139. q = 1e-9
  140. for k, v in qtwt.items():
  141. q += v * v
  142. d = 1e-9
  143. for k, v in dtwt.items():
  144. d += v * v
  145. return s / q#math.sqrt(q) / math.sqrt(d)