您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. - [LightRag](https://github.com/HKUDS/LightRAG)
  7. """
  8. import html
  9. import json
  10. import logging
  11. import re
  12. import time
  13. from collections import defaultdict
  14. from copy import deepcopy
  15. from hashlib import md5
  16. from typing import Any, Callable
  17. import networkx as nx
  18. import numpy as np
  19. import xxhash
  20. from networkx.readwrite import json_graph
  21. from api import settings
  22. from rag.nlp import search, rag_tokenizer
  23. from rag.utils.doc_store_conn import OrderByExpr
  24. from rag.utils.redis_conn import REDIS_CONN
# Type alias for an error-handling callback: it receives an optional exception,
# an optional string and an optional dict, and returns nothing.
# NOTE(review): the meaning of the string/dict arguments is not visible in this file — confirm with callers.
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
  26. def perform_variable_replacements(
  27. input: str, history: list[dict] | None = None, variables: dict | None = None
  28. ) -> str:
  29. """Perform variable replacements on the input string and in a chat log."""
  30. if history is None:
  31. history = []
  32. if variables is None:
  33. variables = {}
  34. result = input
  35. def replace_all(input: str) -> str:
  36. result = input
  37. for k, v in variables.items():
  38. result = result.replace(f"{{{k}}}", v)
  39. return result
  40. result = replace_all(result)
  41. for i, entry in enumerate(history):
  42. if entry.get("role") == "system":
  43. entry["content"] = replace_all(entry.get("content") or "")
  44. return result
  45. def clean_str(input: Any) -> str:
  46. """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
  47. # If we get non-string input, just give it back
  48. if not isinstance(input, str):
  49. return input
  50. result = html.unescape(input.strip())
  51. # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
  52. return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)
  53. def dict_has_keys_with_types(
  54. data: dict, expected_fields: list[tuple[str, type]]
  55. ) -> bool:
  56. """Return True if the given dictionary has the given keys with the given types."""
  57. for field, field_type in expected_fields:
  58. if field not in data:
  59. return False
  60. value = data[field]
  61. if not isinstance(value, field_type):
  62. return False
  63. return True
  64. def get_llm_cache(llmnm, txt, history, genconf):
  65. hasher = xxhash.xxh64()
  66. hasher.update(str(llmnm).encode("utf-8"))
  67. hasher.update(str(txt).encode("utf-8"))
  68. hasher.update(str(history).encode("utf-8"))
  69. hasher.update(str(genconf).encode("utf-8"))
  70. k = hasher.hexdigest()
  71. bin = REDIS_CONN.get(k)
  72. if not bin:
  73. return
  74. return bin
  75. def set_llm_cache(llmnm, txt, v, history, genconf):
  76. hasher = xxhash.xxh64()
  77. hasher.update(str(llmnm).encode("utf-8"))
  78. hasher.update(str(txt).encode("utf-8"))
  79. hasher.update(str(history).encode("utf-8"))
  80. hasher.update(str(genconf).encode("utf-8"))
  81. k = hasher.hexdigest()
  82. REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)
  83. def get_embed_cache(llmnm, txt):
  84. hasher = xxhash.xxh64()
  85. hasher.update(str(llmnm).encode("utf-8"))
  86. hasher.update(str(txt).encode("utf-8"))
  87. k = hasher.hexdigest()
  88. bin = REDIS_CONN.get(k)
  89. if not bin:
  90. return
  91. return np.array(json.loads(bin))
  92. def set_embed_cache(llmnm, txt, arr):
  93. hasher = xxhash.xxh64()
  94. hasher.update(str(llmnm).encode("utf-8"))
  95. hasher.update(str(txt).encode("utf-8"))
  96. k = hasher.hexdigest()
  97. arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
  98. REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
  99. def get_tags_from_cache(kb_ids):
  100. hasher = xxhash.xxh64()
  101. hasher.update(str(kb_ids).encode("utf-8"))
  102. k = hasher.hexdigest()
  103. bin = REDIS_CONN.get(k)
  104. if not bin:
  105. return
  106. return bin
  107. def set_tags_to_cache(kb_ids, tags):
  108. hasher = xxhash.xxh64()
  109. hasher.update(str(kb_ids).encode("utf-8"))
  110. k = hasher.hexdigest()
  111. REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
  112. def graph_merge(g1, g2):
  113. g = g2.copy()
  114. for n, attr in g1.nodes(data=True):
  115. if n not in g2.nodes():
  116. g.add_node(n, **attr)
  117. continue
  118. for source, target, attr in g1.edges(data=True):
  119. if g.has_edge(source, target):
  120. g[source][target].update({"weight": attr.get("weight", 0)+1})
  121. continue
  122. g.add_edge(source, target)#, **attr)
  123. for node_degree in g.degree:
  124. g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  125. return g
  126. def compute_args_hash(*args):
  127. return md5(str(args).encode()).hexdigest()
  128. def handle_single_entity_extraction(
  129. record_attributes: list[str],
  130. chunk_key: str,
  131. ):
  132. if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
  133. return None
  134. # add this record as a node in the G
  135. entity_name = clean_str(record_attributes[1].upper())
  136. if not entity_name.strip():
  137. return None
  138. entity_type = clean_str(record_attributes[2].upper())
  139. entity_description = clean_str(record_attributes[3])
  140. entity_source_id = chunk_key
  141. return dict(
  142. entity_name=entity_name.upper(),
  143. entity_type=entity_type.upper(),
  144. description=entity_description,
  145. source_id=entity_source_id,
  146. )
  147. def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
  148. if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
  149. return None
  150. # add this record as edge
  151. source = clean_str(record_attributes[1].upper())
  152. target = clean_str(record_attributes[2].upper())
  153. edge_description = clean_str(record_attributes[3])
  154. edge_keywords = clean_str(record_attributes[4])
  155. edge_source_id = chunk_key
  156. weight = (
  157. float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
  158. )
  159. pair = sorted([source.upper(), target.upper()])
  160. return dict(
  161. src_id=pair[0],
  162. tgt_id=pair[1],
  163. weight=weight,
  164. description=edge_description,
  165. keywords=edge_keywords,
  166. source_id=edge_source_id,
  167. metadata={"created_at": time.time()},
  168. )
  169. def pack_user_ass_to_openai_messages(*args: str):
  170. roles = ["user", "assistant"]
  171. return [
  172. {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
  173. ]
  174. def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
  175. """Split a string by multiple markers"""
  176. if not markers:
  177. return [content]
  178. results = re.split("|".join(re.escape(marker) for marker in markers), content)
  179. return [r.strip() for r in results if r.strip()]
  180. def is_float_regex(value):
  181. return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
  182. def chunk_id(chunk):
  183. return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
  184. def get_entity(tenant_id, kb_id, ent_name):
  185. conds = {
  186. "fields": ["content_with_weight"],
  187. "entity_kwd": ent_name,
  188. "size": 10000,
  189. "knowledge_graph_kwd": ["entity"]
  190. }
  191. res = []
  192. es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
  193. for id in es_res.ids:
  194. try:
  195. if isinstance(ent_name, str):
  196. return json.loads(es_res.field[id]["content_with_weight"])
  197. res.append(json.loads(es_res.field[id]["content_with_weight"]))
  198. except Exception:
  199. continue
  200. return res
  201. def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
  202. chunk = {
  203. "important_kwd": [ent_name],
  204. "title_tks": rag_tokenizer.tokenize(ent_name),
  205. "entity_kwd": ent_name,
  206. "knowledge_graph_kwd": "entity",
  207. "entity_type_kwd": meta["entity_type"],
  208. "content_with_weight": json.dumps(meta, ensure_ascii=False),
  209. "content_ltks": rag_tokenizer.tokenize(meta["description"]),
  210. "source_id": list(set(meta["source_id"])),
  211. "kb_id": kb_id,
  212. "available_int": 0
  213. }
  214. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  215. res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
  216. search.index_name(tenant_id), [kb_id])
  217. if res.ids:
  218. settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
  219. else:
  220. ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
  221. if ebd is None:
  222. try:
  223. ebd, _ = embd_mdl.encode([ent_name])
  224. ebd = ebd[0]
  225. set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
  226. except Exception as e:
  227. logging.exception(f"Fail to embed entity: {e}")
  228. if ebd is not None:
  229. chunk["q_%d_vec" % len(ebd)] = ebd
  230. settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
  231. def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
  232. ents = from_ent_name
  233. if isinstance(ents, str):
  234. ents = [from_ent_name]
  235. if isinstance(to_ent_name, str):
  236. to_ent_name = [to_ent_name]
  237. ents.extend(to_ent_name)
  238. ents = list(set(ents))
  239. conds = {
  240. "fields": ["content_with_weight"],
  241. "size": size,
  242. "from_entity_kwd": ents,
  243. "to_entity_kwd": ents,
  244. "knowledge_graph_kwd": ["relation"]
  245. }
  246. res = []
  247. es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
  248. for id in es_res.ids:
  249. try:
  250. if size == 1:
  251. return json.loads(es_res.field[id]["content_with_weight"])
  252. res.append(json.loads(es_res.field[id]["content_with_weight"]))
  253. except Exception:
  254. continue
  255. return res
  256. def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
  257. chunk = {
  258. "from_entity_kwd": from_ent_name,
  259. "to_entity_kwd": to_ent_name,
  260. "knowledge_graph_kwd": "relation",
  261. "content_with_weight": json.dumps(meta, ensure_ascii=False),
  262. "content_ltks": rag_tokenizer.tokenize(meta["description"]),
  263. "important_kwd": meta["keywords"],
  264. "source_id": list(set(meta["source_id"])),
  265. "weight_int": int(meta["weight"]),
  266. "kb_id": kb_id,
  267. "available_int": 0
  268. }
  269. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  270. res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
  271. search.index_name(tenant_id), [kb_id])
  272. if res.ids:
  273. settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
  274. chunk,
  275. search.index_name(tenant_id), kb_id)
  276. else:
  277. txt = f"{from_ent_name}->{to_ent_name}"
  278. ebd = get_embed_cache(embd_mdl.llm_name, txt)
  279. if ebd is None:
  280. try:
  281. ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"])
  282. ebd = ebd[0]
  283. set_embed_cache(embd_mdl.llm_name, txt, ebd)
  284. except Exception as e:
  285. logging.exception(f"Fail to embed entity relation: {e}")
  286. if ebd is not None:
  287. chunk["q_%d_vec" % len(ebd)] = ebd
  288. settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
  289. def get_graph(tenant_id, kb_id):
  290. conds = {
  291. "fields": ["content_with_weight", "source_id"],
  292. "removed_kwd": "N",
  293. "size": 1,
  294. "knowledge_graph_kwd": ["graph"]
  295. }
  296. res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
  297. for id in res.ids:
  298. try:
  299. return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
  300. res.field[id]["source_id"]
  301. except Exception:
  302. continue
  303. return rebuild_graph(tenant_id, kb_id)
  304. def set_graph(tenant_id, kb_id, graph, docids):
  305. chunk = {
  306. "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
  307. indent=2),
  308. "knowledge_graph_kwd": "graph",
  309. "kb_id": kb_id,
  310. "source_id": list(docids),
  311. "available_int": 0,
  312. "removed_kwd": "N"
  313. }
  314. res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
  315. if res.ids:
  316. settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
  317. search.index_name(tenant_id), kb_id)
  318. else:
  319. settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
  320. def is_continuous_subsequence(subseq, seq):
  321. def find_all_indexes(tup, value):
  322. indexes = []
  323. start = 0
  324. while True:
  325. try:
  326. index = tup.index(value, start)
  327. indexes.append(index)
  328. start = index + 1
  329. except ValueError:
  330. break
  331. return indexes
  332. index_list = find_all_indexes(seq,subseq[0])
  333. for idx in index_list:
  334. if idx!=len(seq)-1:
  335. if seq[idx+1]==subseq[-1]:
  336. return True
  337. return False
  338. def merge_tuples(list1, list2):
  339. result = []
  340. for tup in list1:
  341. last_element = tup[-1]
  342. if last_element in tup[:-1]:
  343. result.append(tup)
  344. else:
  345. matching_tuples = [t for t in list2 if t[0] == last_element]
  346. already_match_flag = 0
  347. for match in matching_tuples:
  348. matchh = (match[1], match[0])
  349. if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
  350. continue
  351. already_match_flag = 1
  352. merged_tuple = tup + match[1:]
  353. result.append(merged_tuple)
  354. if not already_match_flag:
  355. result.append(tup)
  356. return result
  357. def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
  358. def n_neighbor(id):
  359. nonlocal graph, n_hop
  360. count = 0
  361. source_edge = list(graph.edges(id))
  362. if not source_edge:
  363. return []
  364. count = count + 1
  365. while count < n_hop:
  366. count = count + 1
  367. sc_edge = deepcopy(source_edge)
  368. source_edge = []
  369. for pair in sc_edge:
  370. append_edge = list(graph.edges(pair[-1]))
  371. for tuples in merge_tuples([pair], append_edge):
  372. source_edge.append(tuples)
  373. nbrs = []
  374. for path in source_edge:
  375. n = {"path": path, "weights": []}
  376. wts = nx.get_edge_attributes(graph, 'weight')
  377. for i in range(len(path)-1):
  378. f, t = path[i], path[i+1]
  379. n["weights"].append(wts.get((f, t), 0))
  380. nbrs.append(n)
  381. return nbrs
  382. pr = nx.pagerank(graph)
  383. for n, p in pr.items():
  384. graph.nodes[n]["pagerank"] = p
  385. try:
  386. settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
  387. {"rank_flt": p,
  388. "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
  389. search.index_name(tenant_id), kb_id)
  390. except Exception as e:
  391. logging.exception(e)
  392. ty2ents = defaultdict(list)
  393. for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
  394. ty = graph.nodes[p].get("entity_type")
  395. if not ty or len(ty2ents[ty]) > 12:
  396. continue
  397. ty2ents[ty].append(p)
  398. chunk = {
  399. "content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
  400. "kb_id": kb_id,
  401. "knowledge_graph_kwd": "ty2ents",
  402. "available_int": 0
  403. }
  404. res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
  405. search.index_name(tenant_id), [kb_id])
  406. if res.ids:
  407. settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
  408. chunk,
  409. search.index_name(tenant_id), kb_id)
  410. else:
  411. settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
  412. def get_entity_type2sampels(idxnms, kb_ids: list):
  413. es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
  414. "size": 10000,
  415. "fields": ["content_with_weight"]},
  416. idxnms, kb_ids)
  417. res = defaultdict(list)
  418. for id in es_res.ids:
  419. smp = es_res.field[id].get("content_with_weight")
  420. if not smp:
  421. continue
  422. try:
  423. smp = json.loads(smp)
  424. except Exception as e:
  425. logging.exception(e)
  426. for ty, ents in smp.items():
  427. res[ty].extend(ents)
  428. return res
  429. def flat_uniq_list(arr, key):
  430. res = []
  431. for a in arr:
  432. a = a[key]
  433. if isinstance(a, list):
  434. res.extend(a)
  435. else:
  436. res.append(a)
  437. return list(set(res))
  438. def rebuild_graph(tenant_id, kb_id):
  439. graph = nx.Graph()
  440. src_ids = []
  441. flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
  442. bs = 256
  443. for i in range(0, 10000000, bs):
  444. es_res = settings.docStoreConn.search(flds, [],
  445. {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
  446. [],
  447. OrderByExpr(),
  448. i, bs, search.index_name(tenant_id), [kb_id]
  449. )
  450. tot = settings.docStoreConn.getTotal(es_res)
  451. if tot == 0:
  452. return None, None
  453. es_res = settings.docStoreConn.getFields(es_res, flds)
  454. for id, d in es_res.items():
  455. src_ids.extend(d.get("source_id", []))
  456. if d["knowledge_graph_kwd"] == "entity":
  457. graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
  458. else:
  459. graph.add_edge(
  460. d["from_entity_kwd"],
  461. d["to_entity_kwd"],
  462. weight=int(d["weight_int"])
  463. )
  464. if len(es_res.keys()) < 128:
  465. return graph, list(set(src_ids))
  466. return graph, list(set(src_ids))