You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. - [LightRag](https://github.com/HKUDS/LightRAG)
  7. """
  8. import html
  9. import json
  10. import logging
  11. import re
  12. import time
  13. from collections import defaultdict
  14. from hashlib import md5
  15. from typing import Any, Callable
  16. import os
  17. import trio
  18. from typing import Set, Tuple
  19. import networkx as nx
  20. import numpy as np
  21. import xxhash
  22. from networkx.readwrite import json_graph
  23. import dataclasses
  24. from api import settings
  25. from api.utils import get_uuid
  26. from rag.nlp import search, rag_tokenizer
  27. from rag.utils.doc_store_conn import OrderByExpr
  28. from rag.utils.redis_conn import REDIS_CONN
  29. GRAPH_FIELD_SEP = "<SEP>"
  30. ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
  31. chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
  32. @dataclasses.dataclass
  33. class GraphChange:
  34. removed_nodes: Set[str] = dataclasses.field(default_factory=set)
  35. added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
  36. removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
  37. added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
  38. def perform_variable_replacements(
  39. input: str, history: list[dict] | None = None, variables: dict | None = None
  40. ) -> str:
  41. """Perform variable replacements on the input string and in a chat log."""
  42. if history is None:
  43. history = []
  44. if variables is None:
  45. variables = {}
  46. result = input
  47. def replace_all(input: str) -> str:
  48. result = input
  49. for k, v in variables.items():
  50. result = result.replace(f"{{{k}}}", str(v))
  51. return result
  52. result = replace_all(result)
  53. for i, entry in enumerate(history):
  54. if entry.get("role") == "system":
  55. entry["content"] = replace_all(entry.get("content") or "")
  56. return result
  57. def clean_str(input: Any) -> str:
  58. """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
  59. # If we get non-string input, just give it back
  60. if not isinstance(input, str):
  61. return input
  62. result = html.unescape(input.strip())
  63. # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
  64. return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)
  65. def dict_has_keys_with_types(
  66. data: dict, expected_fields: list[tuple[str, type]]
  67. ) -> bool:
  68. """Return True if the given dictionary has the given keys with the given types."""
  69. for field, field_type in expected_fields:
  70. if field not in data:
  71. return False
  72. value = data[field]
  73. if not isinstance(value, field_type):
  74. return False
  75. return True
  76. def get_llm_cache(llmnm, txt, history, genconf):
  77. hasher = xxhash.xxh64()
  78. hasher.update(str(llmnm).encode("utf-8"))
  79. hasher.update(str(txt).encode("utf-8"))
  80. hasher.update(str(history).encode("utf-8"))
  81. hasher.update(str(genconf).encode("utf-8"))
  82. k = hasher.hexdigest()
  83. bin = REDIS_CONN.get(k)
  84. if not bin:
  85. return
  86. return bin
  87. def set_llm_cache(llmnm, txt, v, history, genconf):
  88. hasher = xxhash.xxh64()
  89. hasher.update(str(llmnm).encode("utf-8"))
  90. hasher.update(str(txt).encode("utf-8"))
  91. hasher.update(str(history).encode("utf-8"))
  92. hasher.update(str(genconf).encode("utf-8"))
  93. k = hasher.hexdigest()
  94. REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)
  95. def get_embed_cache(llmnm, txt):
  96. hasher = xxhash.xxh64()
  97. hasher.update(str(llmnm).encode("utf-8"))
  98. hasher.update(str(txt).encode("utf-8"))
  99. k = hasher.hexdigest()
  100. bin = REDIS_CONN.get(k)
  101. if not bin:
  102. return
  103. return np.array(json.loads(bin))
  104. def set_embed_cache(llmnm, txt, arr):
  105. hasher = xxhash.xxh64()
  106. hasher.update(str(llmnm).encode("utf-8"))
  107. hasher.update(str(txt).encode("utf-8"))
  108. k = hasher.hexdigest()
  109. arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
  110. REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
  111. def get_tags_from_cache(kb_ids):
  112. hasher = xxhash.xxh64()
  113. hasher.update(str(kb_ids).encode("utf-8"))
  114. k = hasher.hexdigest()
  115. bin = REDIS_CONN.get(k)
  116. if not bin:
  117. return
  118. return bin
  119. def set_tags_to_cache(kb_ids, tags):
  120. hasher = xxhash.xxh64()
  121. hasher.update(str(kb_ids).encode("utf-8"))
  122. k = hasher.hexdigest()
  123. REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
  124. def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
  125. """
  126. Ensure all nodes and edges in the graph have some essential attribute.
  127. """
  128. def is_valid_item(node_attrs: dict) -> bool:
  129. valid_node = True
  130. for attr in ["description", "source_id"]:
  131. if attr not in node_attrs:
  132. valid_node = False
  133. break
  134. return valid_node
  135. if check_attribute:
  136. purged_nodes = []
  137. for node, node_attrs in graph.nodes(data=True):
  138. if not is_valid_item(node_attrs):
  139. purged_nodes.append(node)
  140. for node in purged_nodes:
  141. graph.remove_node(node)
  142. if purged_nodes and callback:
  143. callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")
  144. purged_edges = []
  145. for source, target, attr in graph.edges(data=True):
  146. if check_attribute:
  147. if not is_valid_item(attr):
  148. purged_edges.append((source, target))
  149. if "keywords" not in attr:
  150. attr["keywords"] = []
  151. for source, target in purged_edges:
  152. graph.remove_edge(source, target)
  153. if purged_edges and callback:
  154. callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
  155. def get_from_to(node1, node2):
  156. if node1 < node2:
  157. return (node1, node2)
  158. else:
  159. return (node2, node1)
  160. def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
  161. """Merge graph g2 into g1 in place."""
  162. for node_name, attr in g2.nodes(data=True):
  163. change.added_updated_nodes.add(node_name)
  164. if not g1.has_node(node_name):
  165. g1.add_node(node_name, **attr)
  166. continue
  167. node = g1.nodes[node_name]
  168. node["description"] += GRAPH_FIELD_SEP + attr["description"]
  169. # A node's source_id indicates which chunks it came from.
  170. node["source_id"] += attr["source_id"]
  171. for source, target, attr in g2.edges(data=True):
  172. change.added_updated_edges.add(get_from_to(source, target))
  173. edge = g1.get_edge_data(source, target)
  174. if edge is None:
  175. g1.add_edge(source, target, **attr)
  176. continue
  177. edge["weight"] += attr.get("weight", 0)
  178. edge["description"] += GRAPH_FIELD_SEP + attr["description"]
  179. edge["keywords"] += attr["keywords"]
  180. # A edge's source_id indicates which chunks it came from.
  181. edge["source_id"] += attr["source_id"]
  182. for node_degree in g1.degree:
  183. g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  184. # A graph's source_id indicates which documents it came from.
  185. if "source_id" not in g1.graph:
  186. g1.graph["source_id"] = []
  187. g1.graph["source_id"] += g2.graph.get("source_id", [])
  188. return g1
  189. def compute_args_hash(*args):
  190. return md5(str(args).encode()).hexdigest()
  191. def handle_single_entity_extraction(
  192. record_attributes: list[str],
  193. chunk_key: str,
  194. ):
  195. if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
  196. return None
  197. # add this record as a node in the G
  198. entity_name = clean_str(record_attributes[1].upper())
  199. if not entity_name.strip():
  200. return None
  201. entity_type = clean_str(record_attributes[2].upper())
  202. entity_description = clean_str(record_attributes[3])
  203. entity_source_id = chunk_key
  204. return dict(
  205. entity_name=entity_name.upper(),
  206. entity_type=entity_type.upper(),
  207. description=entity_description,
  208. source_id=entity_source_id,
  209. )
  210. def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
  211. if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
  212. return None
  213. # add this record as edge
  214. source = clean_str(record_attributes[1].upper())
  215. target = clean_str(record_attributes[2].upper())
  216. edge_description = clean_str(record_attributes[3])
  217. edge_keywords = clean_str(record_attributes[4])
  218. edge_source_id = chunk_key
  219. weight = (
  220. float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
  221. )
  222. pair = sorted([source.upper(), target.upper()])
  223. return dict(
  224. src_id=pair[0],
  225. tgt_id=pair[1],
  226. weight=weight,
  227. description=edge_description,
  228. keywords=edge_keywords,
  229. source_id=edge_source_id,
  230. metadata={"created_at": time.time()},
  231. )
  232. def pack_user_ass_to_openai_messages(*args: str):
  233. roles = ["user", "assistant"]
  234. return [
  235. {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
  236. ]
  237. def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
  238. """Split a string by multiple markers"""
  239. if not markers:
  240. return [content]
  241. results = re.split("|".join(re.escape(marker) for marker in markers), content)
  242. return [r.strip() for r in results if r.strip()]
  243. def is_float_regex(value):
  244. return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
  245. def chunk_id(chunk):
  246. return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
  247. async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
  248. chunk = {
  249. "id": get_uuid(),
  250. "important_kwd": [ent_name],
  251. "title_tks": rag_tokenizer.tokenize(ent_name),
  252. "entity_kwd": ent_name,
  253. "knowledge_graph_kwd": "entity",
  254. "entity_type_kwd": meta["entity_type"],
  255. "content_with_weight": json.dumps(meta, ensure_ascii=False),
  256. "content_ltks": rag_tokenizer.tokenize(meta["description"]),
  257. "source_id": meta["source_id"],
  258. "kb_id": kb_id,
  259. "available_int": 0
  260. }
  261. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  262. ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
  263. if ebd is None:
  264. ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
  265. ebd = ebd[0]
  266. set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
  267. assert ebd is not None
  268. chunk["q_%d_vec" % len(ebd)] = ebd
  269. chunks.append(chunk)
  270. def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
  271. ents = from_ent_name
  272. if isinstance(ents, str):
  273. ents = [from_ent_name]
  274. if isinstance(to_ent_name, str):
  275. to_ent_name = [to_ent_name]
  276. ents.extend(to_ent_name)
  277. ents = list(set(ents))
  278. conds = {
  279. "fields": ["content_with_weight"],
  280. "size": size,
  281. "from_entity_kwd": ents,
  282. "to_entity_kwd": ents,
  283. "knowledge_graph_kwd": ["relation"]
  284. }
  285. res = []
  286. es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
  287. for id in es_res.ids:
  288. try:
  289. if size == 1:
  290. return json.loads(es_res.field[id]["content_with_weight"])
  291. res.append(json.loads(es_res.field[id]["content_with_weight"]))
  292. except Exception:
  293. continue
  294. return res
  295. async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
  296. chunk = {
  297. "id": get_uuid(),
  298. "from_entity_kwd": from_ent_name,
  299. "to_entity_kwd": to_ent_name,
  300. "knowledge_graph_kwd": "relation",
  301. "content_with_weight": json.dumps(meta, ensure_ascii=False),
  302. "content_ltks": rag_tokenizer.tokenize(meta["description"]),
  303. "important_kwd": meta["keywords"],
  304. "source_id": meta["source_id"],
  305. "weight_int": int(meta["weight"]),
  306. "kb_id": kb_id,
  307. "available_int": 0
  308. }
  309. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  310. txt = f"{from_ent_name}->{to_ent_name}"
  311. ebd = get_embed_cache(embd_mdl.llm_name, txt)
  312. if ebd is None:
  313. ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
  314. ebd = ebd[0]
  315. set_embed_cache(embd_mdl.llm_name, txt, ebd)
  316. assert ebd is not None
  317. chunk["q_%d_vec" % len(ebd)] = ebd
  318. chunks.append(chunk)
  319. async def does_graph_contains(tenant_id, kb_id, doc_id):
  320. # Get doc_ids of graph
  321. fields = ["source_id"]
  322. condition = {
  323. "knowledge_graph_kwd": ["graph"],
  324. "removed_kwd": "N",
  325. }
  326. res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
  327. fields2 = settings.docStoreConn.getFields(res, fields)
  328. graph_doc_ids = set()
  329. for chunk_id in fields2.keys():
  330. graph_doc_ids = set(fields2[chunk_id]["source_id"])
  331. return doc_id in graph_doc_ids
  332. async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
  333. conds = {
  334. "fields": ["source_id"],
  335. "removed_kwd": "N",
  336. "size": 1,
  337. "knowledge_graph_kwd": ["graph"]
  338. }
  339. res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
  340. doc_ids = []
  341. if res.total == 0:
  342. return doc_ids
  343. for id in res.ids:
  344. doc_ids = res.field[id]["source_id"]
  345. return doc_ids
  346. async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
  347. conds = {
  348. "fields": ["content_with_weight", "removed_kwd", "source_id"],
  349. "size": 1,
  350. "knowledge_graph_kwd": ["graph"]
  351. }
  352. res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
  353. if not res.total == 0:
  354. for id in res.ids:
  355. try:
  356. if res.field[id]["removed_kwd"] == "N":
  357. g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
  358. if "source_id" not in g.graph:
  359. g.graph["source_id"] = res.field[id]["source_id"]
  360. else:
  361. g = await rebuild_graph(tenant_id, kb_id, exclude_rebuild)
  362. return g
  363. except Exception:
  364. continue
  365. result = None
  366. return result
  367. async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
  368. start = trio.current_time()
  369. await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id))
  370. if change.removed_nodes:
  371. await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id))
  372. if change.removed_edges:
  373. async with trio.open_nursery() as nursery:
  374. for from_node, to_node in change.removed_edges:
  375. nursery.start_soon(lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
  376. now = trio.current_time()
  377. if callback:
  378. callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
  379. start = now
  380. chunks = [{
  381. "id": get_uuid(),
  382. "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
  383. "knowledge_graph_kwd": "graph",
  384. "kb_id": kb_id,
  385. "source_id": graph.graph.get("source_id", []),
  386. "available_int": 0,
  387. "removed_kwd": "N"
  388. }]
  389. # generate updated subgraphs
  390. for source in graph.graph["source_id"]:
  391. subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy()
  392. subgraph.graph["source_id"] = [source]
  393. for n in subgraph.nodes:
  394. subgraph.nodes[n]["source_id"] = [source]
  395. chunks.append({
  396. "id": get_uuid(),
  397. "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
  398. "knowledge_graph_kwd": "subgraph",
  399. "kb_id": kb_id,
  400. "source_id": [source],
  401. "available_int": 0,
  402. "removed_kwd": "N"
  403. })
  404. async with trio.open_nursery() as nursery:
  405. for node in change.added_updated_nodes:
  406. node_attrs = graph.nodes[node]
  407. nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
  408. for from_node, to_node in change.added_updated_edges:
  409. edge_attrs = graph.get_edge_data(from_node, to_node)
  410. if not edge_attrs:
  411. # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
  412. continue
  413. nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
  414. now = trio.current_time()
  415. if callback:
  416. callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
  417. start = now
  418. es_bulk_size = 4
  419. for b in range(0, len(chunks), es_bulk_size):
  420. doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
  421. if doc_store_result:
  422. error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
  423. raise Exception(error_message)
  424. now = trio.current_time()
  425. if callback:
  426. callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
  427. def is_continuous_subsequence(subseq, seq):
  428. def find_all_indexes(tup, value):
  429. indexes = []
  430. start = 0
  431. while True:
  432. try:
  433. index = tup.index(value, start)
  434. indexes.append(index)
  435. start = index + 1
  436. except ValueError:
  437. break
  438. return indexes
  439. index_list = find_all_indexes(seq,subseq[0])
  440. for idx in index_list:
  441. if idx!=len(seq)-1:
  442. if seq[idx+1]==subseq[-1]:
  443. return True
  444. return False
  445. def merge_tuples(list1, list2):
  446. result = []
  447. for tup in list1:
  448. last_element = tup[-1]
  449. if last_element in tup[:-1]:
  450. result.append(tup)
  451. else:
  452. matching_tuples = [t for t in list2 if t[0] == last_element]
  453. already_match_flag = 0
  454. for match in matching_tuples:
  455. matchh = (match[1], match[0])
  456. if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
  457. continue
  458. already_match_flag = 1
  459. merged_tuple = tup + match[1:]
  460. result.append(merged_tuple)
  461. if not already_match_flag:
  462. result.append(tup)
  463. return result
  464. async def get_entity_type2sampels(idxnms, kb_ids: list):
  465. es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
  466. "size": 10000,
  467. "fields": ["content_with_weight"]},
  468. idxnms, kb_ids))
  469. res = defaultdict(list)
  470. for id in es_res.ids:
  471. smp = es_res.field[id].get("content_with_weight")
  472. if not smp:
  473. continue
  474. try:
  475. smp = json.loads(smp)
  476. except Exception as e:
  477. logging.exception(e)
  478. for ty, ents in smp.items():
  479. res[ty].extend(ents)
  480. return res
  481. def flat_uniq_list(arr, key):
  482. res = []
  483. for a in arr:
  484. a = a[key]
  485. if isinstance(a, list):
  486. res.extend(a)
  487. else:
  488. res.append(a)
  489. return list(set(res))
  490. async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
  491. graph = nx.Graph()
  492. flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
  493. bs = 256
  494. for i in range(0, 1024*bs, bs):
  495. es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
  496. {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
  497. [],
  498. OrderByExpr(),
  499. i, bs, search.index_name(tenant_id), [kb_id]
  500. ))
  501. # tot = settings.docStoreConn.getTotal(es_res)
  502. es_res = settings.docStoreConn.getFields(es_res, flds)
  503. if len(es_res) == 0:
  504. break
  505. for id, d in es_res.items():
  506. assert d["knowledge_graph_kwd"] == "subgraph"
  507. if isinstance(exclude_rebuild, list):
  508. if sum([n in d["source_id"] for n in exclude_rebuild]):
  509. continue
  510. elif exclude_rebuild in d["source_id"]:
  511. continue
  512. next_graph = json_graph.node_link_graph(json.loads(d["content_with_weight"]), edges="edges")
  513. merged_graph = nx.compose(graph, next_graph)
  514. merged_source = {
  515. n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"]
  516. for n in graph.nodes & next_graph.nodes
  517. }
  518. nx.set_node_attributes(merged_graph, merged_source, "source_id")
  519. if "source_id" in graph.graph:
  520. merged_graph.graph["source_id"] = graph.graph["source_id"] + next_graph.graph["source_id"]
  521. else:
  522. merged_graph.graph["source_id"] = next_graph.graph["source_id"]
  523. graph = merged_graph
  524. if len(graph.nodes) == 0:
  525. return None
  526. graph.graph["source_id"] = sorted(graph.graph["source_id"])
  527. return graph