
search.py 8.1KB

# -*- coding: utf-8 -*-
import re
from dataclasses import dataclass
from typing import List, Optional, Dict, Union

import numpy as np
from elasticsearch_dsl import Q, Search

from rag.utils import rmSpace
from rag.nlp import huqie, query


def index_name(uid): return f"ragflow_{uid}"
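

# Dealer bundles the Elasticsearch connection and the embedding model, and
# exposes hybrid (full-text + vector) retrieval plus result post-processing.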
class Dealer:
    def __init__(self, es, emb_mdl):
        self.qryr = query.EsQueryer(es)
        self.qryr.flds = [
            "title_tks^10",
            "title_sm_tks^5",
            "content_ltks^2",
            "content_sm_ltks"]
        self.es = es
        self.emb_mdl = emb_mdl
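
    # SearchResult carries everything downstream consumers need: hit ids,
    # per-hit fields, highlights, aggregations, keywords and the query vector.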
    @dataclass
    class SearchResult:
        total: int
        ids: List[str]
        query_vector: Optional[List[float]] = None
        field: Optional[Dict] = None
        highlight: Optional[Dict] = None
        aggregation: Union[List, Dict, None] = None
        keywords: Optional[List[str]] = None
        group_docs: Optional[List[List]] = None
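
    # Builds the ES kNN clause over the stored `q_vec` field: `sim` is the
    # minimum similarity and `topk` the number of neighbours to return.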
    def _vector(self, txt, sim=0.8, topk=10):
        return {
            "field": "q_vec",
            "k": topk,
            "similarity": sim,
            "num_candidates": 1000,
            "query_vector": self.emb_mdl.encode_queries(txt)
        }
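
    # Runs the hybrid query: a boosted full-text query over the tokenised
    # fields, optionally combined with a kNN clause on the question
    # embedding, plus a looser retry when the strict pass matches nothing.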
    def search(self, req, idxnm, tks_num=3):
        qst = req.get("question", "")
        bqry, keywords = self.qryr.question(qst)
        if req.get("kb_ids"):
            bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
        bqry.filter.append(Q("exists", field="q_tks"))
        bqry.boost = 0.05
        print(bqry)

        s = Search()
        pg = int(req.get("page", 1)) - 1
        ps = int(req.get("size", 1000))
        src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
                                "image_id", "doc_id", "q_vec"])

        s = s.query(bqry)[pg * ps:(pg + 1) * ps]
        s = s.highlight("content_ltks")
        s = s.highlight("title_ltks")
        if not qst:
            # No question given: fall back to most-recently-created first.
            s = s.sort(
                {"create_time": {"order": "desc", "unmapped_type": "date"}})

        s = s.highlight_options(
            fragment_size=120,
            number_of_fragments=5,
            boundary_scanner_locale="zh-CN",
            boundary_scanner="SENTENCE",
            boundary_chars=",./;:\\!(),。?:!……()——、"
        )
        s = s.to_dict()

        q_vec = []
        if req.get("vector"):
            # Attach the kNN clause; highlighting is dropped alongside it.
            s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
            s["knn"]["filter"] = bqry.to_dict()
            del s["highlight"]
            q_vec = s["knn"]["query_vector"]

        res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
        print("TOTAL: ", self.es.getTotal(res))

        if self.es.getTotal(res) == 0 and "knn" in s:
            # Nothing matched: retry with a much looser keyword match
            # (10% min_match) and a stricter vector-similarity cut-off.
            bqry, _ = self.qryr.question(qst, min_match="10%")
            if req.get("kb_ids"):
                bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
            s["query"] = bqry.to_dict()
            s["knn"]["filter"] = bqry.to_dict()
            s["knn"]["similarity"] = 0.7
            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)

        # Expand the query keywords with their fine-grained sub-tokens.
        kwds = set([])
        for k in keywords:
            kwds.add(k)
            for kk in huqie.qieqie(k).split(" "):
                if len(kk) < 2:
                    continue
                if kk in kwds:
                    continue
                kwds.add(kk)

        aggs = self.getAggregation(res, "docnm_kwd")

        return self.SearchResult(
            total=self.es.getTotal(res),
            ids=self.es.getDocIds(res),
            query_vector=q_vec,
            aggregation=aggs,
            highlight=self.getHighlight(res),
            field=self.getFields(res, ["docnm_kwd", "content_ltks",
                                       "kb_id", "image_id", "doc_id", "q_vec"]),
            keywords=list(kwds)
        )
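
    # Illustrative request shape (keys as consumed by search() above):
    #   req = {"question": "...", "kb_ids": ["..."], "page": 1, "size": 30,
    #          "vector": True, "similarity": 0.4}
    #   sres = dealer.search(req, index_name(uid))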
    def getAggregation(self, res, g):
        if "aggregations" not in res or "aggs_" + g not in res["aggregations"]:
            return
        bkts = res["aggregations"]["aggs_" + g]["buckets"]
        return [(b["key"], b["doc_count"]) for b in bkts]
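
    # Maps hit _id -> concatenated highlight fragments from the first
    # highlighted field of each hit.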
    def getHighlight(self, res):
        def rmspace(line):
            # Rejoin tokenised text, keeping a space only between two
            # adjacent English tokens.
            eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
            r = []
            for t in line.split(" "):
                if not t:
                    continue
                if len(r) > 0 and len(t) > 0 \
                        and r[-1][-1] in eng and t[0] in eng:
                    r.append(" ")
                r.append(t)
            return "".join(r)

        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
            if not hlts:
                continue
            ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
        return ans
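
    # Flattens each hit's stored fields: lists become tab-joined strings,
    # everything else is stringified and normalised through rmSpace.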
    def getFields(self, sres, flds):
        res = {}
        if not flds:
            return {}
        for d in self.es.getSource(sres):
            m = {n: d.get(n) for n in flds if d.get(n) is not None}
            for n, v in m.items():
                if isinstance(v, list):
                    m[n] = "\t".join([str(vv) for vv in v])
                    continue
                if not isinstance(v, str):
                    m[n] = str(m[n])
                m[n] = rmSpace(m[n])
            if m:
                res[d["id"]] = m
        return res
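
    # Vectors are stored as tab-separated strings; parse them back to floats.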
    @staticmethod
    def trans2floats(txt):
        return [float(t) for t in txt.split("\t")]
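
    # Splits the answer at sentence-ending punctuation and, for each span of
    # at least 12 characters, appends "@?<idx>?@" markers for the retrieved
    # chunks whose hybrid similarity is close to the maximum (and >= 0.55).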
    def insert_citations(self, ans, top_idx, sres,
                         vfield="q_vec", cfield="content_ltks"):
        ins_embd = [Dealer.trans2floats(
            sres.field[sres.ids[i]][vfield]) for i in top_idx]
        ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
        s = 0
        e = 0
        res = ""

        def citeit():
            nonlocal s, e, ans, res
            if not ins_embd:
                return
            embd = self.emb_mdl.encode(ans[s: e])
            sim = self.qryr.hybrid_similarity(embd,
                                              ins_embd,
                                              huqie.qie(ans[s:e]).split(" "),
                                              ins_tw)
            print(ans[s: e], sim)
            # Cite only chunks within 1% of the best score, at most four of
            # them, and only if the best score clears the 0.55 threshold.
            mx = np.max(sim) * 0.99
            if mx < 0.55:
                return
            cita = list(set([top_idx[i]
                             for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
            for i in cita:
                res += f"@?{i}?@"
            return cita

        punct = set(";。?!!")
        if not self.qryr.isChinese(ans):
            punct.add("?")
            punct.add(".")

        while e < len(ans):
            if e - s < 12 or ans[e] not in punct:
                e += 1
                continue
            # Don't split inside a decimal number like "3.14".
            if ans[e] == "." and e + \
                    1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
                e += 1
                continue
            # Don't split on a dot two characters after a newline
            # (e.g. a list marker such as "\n1.").
            if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
                e += 1
                continue
            res += ans[s: e]
            citeit()
            res += ans[e]
            e += 1
            s = e

        if s < len(ans):
            res += ans[s:]
            citeit()

        return res
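
    # Re-scores all hits against the original query with a weighted mix of
    # token similarity (tkweight) and embedding similarity (vtweight).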
    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
               vfield="q_vec", cfield="content_ltks"):
        ins_embd = [
            Dealer.trans2floats(
                sres.field[i][vfield]) for i in sres.ids]
        if not ins_embd:
            return []
        ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
        # return CosineSimilarity([sres.query_vector], ins_embd)[0]
        sim = self.qryr.hybrid_similarity(sres.query_vector,
                                          ins_embd,
                                          huqie.qie(query).split(" "),
                                          ins_tw, tkweight, vtweight)
        return sim
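

# Smoke test: needs a reachable ES cluster; the kb id and index pattern
# below are environment-specific.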
if __name__ == "__main__":
    from util import es_conn
    # NOTE: Dealer also expects an embedding model; passing None keeps this
    # a keyword-only smoke test (the "vector" branch is never taken here).
    SE = Dealer(es_conn.HuEs("infiniflow"), None)
    qs = [
        "胡凯",
        ""
    ]
    for q in qs:
        print(">>>>>>>>>>>>>>>>>>>>", q)
        print(SE.search(
            {"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))