
search.py 16KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
from typing import List, Optional, Dict, Union
from dataclasses import dataclass

import numpy as np

from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr

def index_name(uid): return f"ragflow_{uid}"


class Dealer:
    def __init__(self, dataStore: DocStoreConnection):
        self.qryr = query.FulltextQueryer()
        self.dataStore = dataStore

    @dataclass
    class SearchResult:
        total: int
        ids: List[str]
        query_vector: List[float] = None
        field: Optional[Dict] = None
        highlight: Optional[Dict] = None
        aggregation: Union[List, Dict, None] = None
        keywords: Optional[List[str]] = None
        group_docs: List[List] = None

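    # Encode the query text and wrap it in a dense-vector match expression;
    # the embedding dimension determines the vector column name (q_<dim>_vec).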
    def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
        qv, _ = emb_mdl.encode_queries(txt)
        embedding_data = [float(v) for v in qv]
        vector_column_name = f"q_{len(embedding_data)}_vec"
        return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})

    def get_filters(self, req):
        condition = dict()
        for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
            if key in req and req[key] is not None:
                condition[field] = req[key]
        # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
        for key in ["knowledge_graph_kwd"]:
            if key in req and req[key] is not None:
                condition[key] = req[key]
        return condition

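    # Full-text and/or dense-vector search. With an embedding model, keyword and
    # vector matches are fused with a weighted sum (0.05 text / 0.95 vector); if
    # nothing matches, the constraints are relaxed and the search is retried once.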
    def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
        filters = self.get_filters(req)
        orderBy = OrderByExpr()

        pg = int(req.get("page", 1)) - 1
        topk = int(req.get("topk", 1024))
        ps = int(req.get("size", topk))
        offset, limit = pg * ps, (pg + 1) * ps

        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "doc_id", "position_list", "knowledge_graph_kwd",
                                 "available_int", "content_with_weight"])
        kwds = set([])

        qst = req.get("question", "")
        q_vec = []
        if not qst:
            if req.get("sort"):
                orderBy.desc("create_timestamp_flt")
            res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))
        else:
            highlightFields = ["content_ltks", "title_tks"] if highlight else []
            matchText, keywords = self.qryr.question(qst, min_match=0.3)
            if emb_mdl is None:
                matchExprs = [matchText]
                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
                total = self.dataStore.getTotal(res)
                logging.debug("Dealer.search TOTAL: {}".format(total))
            else:
                matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
                q_vec = matchDense.embedding_data
                src.append(f"q_{len(q_vec)}_vec")

                fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
                matchExprs = [matchText, matchDense, fusionExpr]

                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
                total = self.dataStore.getTotal(res)
                logging.debug("Dealer.search TOTAL: {}".format(total))

                # If result is empty, try again with lower min_match
                if total == 0:
                    matchText, _ = self.qryr.question(qst, min_match=0.1)
                    if "doc_ids" in filters:
                        del filters["doc_ids"]
                    matchDense.extra_options["similarity"] = 0.17
                    res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
                    total = self.dataStore.getTotal(res)
                    logging.debug("Dealer.search 2 TOTAL: {}".format(total))

            for k in keywords:
                kwds.add(k)
                for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
                    if len(kk) < 2:
                        continue
                    if kk in kwds:
                        continue
                    kwds.add(kk)

        logging.debug(f"TOTAL: {total}")
        ids = self.dataStore.getChunkIds(res)
        keywords = list(kwds)
        highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
        aggs = self.dataStore.getAggregation(res, "docnm_kwd")
        return self.SearchResult(
            total=total,
            ids=ids,
            query_vector=q_vec,
            aggregation=aggs,
            highlight=highlight,
            field=self.dataStore.getFields(res, src),
            keywords=keywords
        )

    @staticmethod
    def trans2floats(txt):
        return [float(t) for t in txt.split("\t")]

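    # Split an answer into sentence-level pieces (keeping fenced code blocks
    # intact), score each piece against the retrieved chunks with a hybrid
    # token/vector similarity, and append " ##<chunk>$$" citation markers to
    # pieces that match above a gradually relaxed threshold.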
    def insert_citations(self, answer, chunks, chunk_v,
                         embd_mdl, tkweight=0.1, vtweight=0.9):
        assert len(chunks) == len(chunk_v)
        if not chunks:
            return answer, set([])
        pieces = re.split(r"(```)", answer)
        if len(pieces) >= 3:
            i = 0
            pieces_ = []
            while i < len(pieces):
                if pieces[i] == "```":
                    st = i
                    i += 1
                    while i < len(pieces) and pieces[i] != "```":
                        i += 1
                    if i < len(pieces):
                        i += 1
                    pieces_.append("".join(pieces[st: i]) + "\n")
                else:
                    pieces_.extend(
                        re.split(
                            r"([^\|][;。?!！\n]|[a-z][.?;!][ \n])",
                            pieces[i]))
                    i += 1
            pieces = pieces_
        else:
            pieces = re.split(r"([^\|][;。?!！\n]|[a-z][.?;!][ \n])", answer)
        for i in range(1, len(pieces)):
            if re.match(r"([^\|][;。?!！\n]|[a-z][.?;!][ \n])", pieces[i]):
                pieces[i - 1] += pieces[i][0]
                pieces[i] = pieces[i][1:]
        idx = []
        pieces_ = []
        for i, t in enumerate(pieces):
            if len(t) < 5:
                continue
            idx.append(i)
            pieces_.append(t)
        logging.debug("{} => {}".format(answer, pieces_))
        if not pieces_:
            return answer, set([])

        ans_v, _ = embd_mdl.encode(pieces_)
        assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
            len(ans_v[0]), len(chunk_v[0]))

        chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split(" ")
                      for ck in chunks]
        cites = {}
        thr = 0.63
        while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
            for i, a in enumerate(pieces_):
                sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
                                                                chunk_v,
                                                                rag_tokenizer.tokenize(
                                                                    self.qryr.rmWWW(pieces_[i])).split(" "),
                                                                chunks_tks,
                                                                tkweight, vtweight)
                mx = np.max(sim) * 0.99
                logging.debug("{} SIM: {}".format(pieces_[i], mx))
                if mx < thr:
                    continue
                cites[idx[i]] = list(
                    set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
            thr *= 0.8

        res = ""
        seted = set([])
        for i, p in enumerate(pieces):
            res += p
            if i not in idx:
                continue
            if i not in cites:
                continue
            for c in cites[i]:
                assert int(c) < len(chunk_v)
            for c in cites[i]:
                if c in seted:
                    continue
                res += f" ##{c}$$"
                seted.add(c)

        return res, seted

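    # Re-rank search results by combining token similarity with vector
    # similarity between the query embedding and each chunk's stored vector.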
    def rerank(self, sres, query, tkweight=0.3,
               vtweight=0.7, cfield="content_ltks"):
        _, keywords = self.qryr.question(query)
        vector_size = len(sres.query_vector)
        vector_column = f"q_{vector_size}_vec"
        zero_vector = [0.0] * vector_size
        ins_embd = []
        for chunk_id in sres.ids:
            vector = sres.field[chunk_id].get(vector_column, zero_vector)
            if isinstance(vector, str):
                vector = [float(v) for v in vector.split("\t")]
            ins_embd.append(vector)
        if not ins_embd:
            return [], [], []

        for i in sres.ids:
            if isinstance(sres.field[i].get("important_kwd", []), str):
                sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
        ins_tw = []
        for i in sres.ids:
            content_ltks = sres.field[i][cfield].split(" ")
            title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
            important_kwd = sres.field[i].get("important_kwd", [])
            tks = content_ltks + title_tks + important_kwd
            ins_tw.append(tks)

        sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                        ins_embd,
                                                        keywords,
                                                        ins_tw, tkweight, vtweight)
        return sim, tksim, vtsim

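    # Re-rank with an external rerank model: token similarity is computed locally,
    # while the semantic score comes from rerank_mdl.similarity().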
    def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
                        vtweight=0.7, cfield="content_ltks"):
        _, keywords = self.qryr.question(query)

        for i in sres.ids:
            if isinstance(sres.field[i].get("important_kwd", []), str):
                sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
        ins_tw = []
        for i in sres.ids:
            content_ltks = sres.field[i][cfield].split(" ")
            title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
            important_kwd = sres.field[i].get("important_kwd", [])
            tks = content_ltks + title_tks + important_kwd
            ins_tw.append(tks)

        tksim = self.qryr.token_similarity(keywords, ins_tw)
        vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])

        return tkweight * np.array(tksim) + vtweight * vtsim, tksim, vtsim

    def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
        return self.qryr.hybrid_similarity(ans_embd,
                                           ins_embd,
                                           rag_tokenizer.tokenize(ans).split(" "),
                                           rag_tokenizer.tokenize(inst).split(" "))

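    # End-to-end retrieval: build the search request, query every tenant index,
    # re-rank the first RERANK_PAGE_LIMIT pages, drop chunks below
    # similarity_threshold, and aggregate hit counts per document.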
    def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
                  vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
        if not question:
            return ranks

        RERANK_PAGE_LIMIT = 3
        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128),
               "question": question, "vector": True, "topk": top,
               "similarity": similarity_threshold,
               "available_int": 1}

        if page > RERANK_PAGE_LIMIT:
            req["page"] = page
            req["size"] = page_size

        if isinstance(tenant_ids, str):
            tenant_ids = tenant_ids.split(",")

        sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
        ranks["total"] = sres.total

        if page <= RERANK_PAGE_LIMIT:
            if rerank_mdl:
                sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
                                                       sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
            else:
                sim, tsim, vsim = self.rerank(
                    sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
            idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
        else:
            sim = tsim = vsim = [1] * len(sres.ids)
            idx = list(range(len(sres.ids)))

        dim = len(sres.query_vector)
        vector_column = f"q_{dim}_vec"
        zero_vector = [0.0] * dim
        for i in idx:
            if sim[i] < similarity_threshold:
                break
            if len(ranks["chunks"]) >= page_size:
                if aggs:
                    continue
                break
            id = sres.ids[i]
            chunk = sres.field[id]
            dnm = chunk["docnm_kwd"]
            did = chunk["doc_id"]
            position_list = chunk.get("position_list", "[]")
            if not position_list:
                position_list = "[]"
            d = {
                "chunk_id": id,
                "content_ltks": chunk["content_ltks"],
                "content_with_weight": chunk["content_with_weight"],
                "doc_id": chunk["doc_id"],
                "docnm_kwd": dnm,
                "kb_id": chunk["kb_id"],
                "important_kwd": chunk.get("important_kwd", []),
                "image_id": chunk.get("img_id", ""),
                "similarity": sim[i],
                "vector_similarity": vsim[i],
                "term_similarity": tsim[i],
                "vector": chunk.get(vector_column, zero_vector),
                "positions": json.loads(position_list)
            }
            if highlight:
                if id in sres.highlight:
                    d["highlight"] = rmSpace(sres.highlight[id])
                else:
                    d["highlight"] = d["content_with_weight"]
            ranks["chunks"].append(d)
            if dnm not in ranks["doc_aggs"]:
                ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
            ranks["doc_aggs"][dnm]["count"] += 1
        ranks["doc_aggs"] = [{"doc_name": k,
                              "doc_id": v["doc_id"],
                              "count": v["count"]} for k,
                             v in sorted(ranks["doc_aggs"].items(),
                                         key=lambda x: x[1]["count"] * -1)]

        return ranks

    def sql_retrieval(self, sql, fetch_size=128, format="json"):
        tbl = self.dataStore.sql(sql, fetch_size, format)
        return tbl

    def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024,
                   fields=["docnm_kwd", "content_with_weight", "img_id"]):
        condition = {"doc_id": doc_id}
        res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
        dict_chunks = self.dataStore.getFields(res, fields)
        return dict_chunks.values()