# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
 - [LightRag](https://github.com/HKUDS/LightRAG)
"""

import dataclasses
import html
import json
import logging
import os
import re
import time
from collections import defaultdict
from hashlib import md5
from typing import Any, Callable, Set, Tuple

import networkx as nx
import numpy as np
import trio
import xxhash
from networkx.readwrite import json_graph

from api import settings
from api.utils import get_uuid
from rag.nlp import search, rag_tokenizer
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN

GRAPH_FIELD_SEP = "<SEP>"

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))


@dataclasses.dataclass
class GraphChange:
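    """Tracks the nodes and edges removed or added/updated while merging new subgraphs into an existing graph."""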
    removed_nodes: Set[str] = dataclasses.field(default_factory=set)
    added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
    removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
    added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)


def perform_variable_replacements(
    input: str, history: list[dict] | None = None, variables: dict | None = None
) -> str:
    """Perform variable replacements on the input string and in a chat log."""
    if history is None:
        history = []
    if variables is None:
        variables = {}
    result = input

    def replace_all(input: str) -> str:
        result = input
        for k, v in variables.items():
            result = result.replace(f"{{{k}}}", str(v))
        return result

    result = replace_all(result)
    for entry in history:
        if entry.get("role") == "system":
            entry["content"] = replace_all(entry.get("content") or "")

    return result


def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)


def dict_has_keys_with_types(
    data: dict, expected_fields: list[tuple[str, type]]
) -> bool:
    """Return True if the given dictionary has the given keys with the given types."""
    for field, field_type in expected_fields:
        if field not in data:
            return False

        value = data[field]
        if not isinstance(value, field_type):
            return False
    return True


def get_llm_cache(llmnm, txt, history, genconf):
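    """Return the cached LLM response for (model, prompt, history, gen config) from Redis, or None on a miss."""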
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin


def set_llm_cache(llmnm, txt, v, history, genconf):
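    """Cache an LLM response in Redis for 24 hours, keyed the same way as get_llm_cache."""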
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)


def get_embed_cache(llmnm, txt):
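    """Return the cached embedding for (model, text) as a numpy array, or None on a miss."""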
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return np.array(json.loads(bin))


def set_embed_cache(llmnm, txt, arr):
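    """Cache an embedding vector in Redis for 24 hours, serialized as JSON."""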
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
    REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)


def get_tags_from_cache(kb_ids):
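    """Return the cached tags JSON for the given knowledge base ids, or None on a miss."""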
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin


def set_tags_to_cache(kb_ids, tags):
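    """Cache the tags for the given knowledge base ids in Redis as JSON for 10 minutes."""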
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback):
    """
    Purge nodes and edges that are missing the essential attributes ("description" and "source_id"),
    and default a missing edge "keywords" attribute to an empty list.
    """
    def is_valid_node(node_attrs: dict) -> bool:
        valid_node = True
        for attr in ["description", "source_id"]:
            if attr not in node_attrs:
                valid_node = False
                break
        return valid_node
    purged_nodes = []
    for node, node_attrs in graph.nodes(data=True):
        if not is_valid_node(node_attrs):
            purged_nodes.append(node)
    for node in purged_nodes:
        graph.remove_node(node)
    if purged_nodes and callback:
        callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")

    purged_edges = []
    for source, target, attr in graph.edges(data=True):
        if not is_valid_node(attr):
            purged_edges.append((source, target))
        if "keywords" not in attr:
            attr["keywords"] = []
    for source, target in purged_edges:
        graph.remove_edge(source, target)
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
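    """Return the two node names as a sorted tuple so an undirected edge always has one canonical key.

    >>> get_from_to("B", "A")
    ('A', 'B')
    """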
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
        change.added_updated_nodes.add(node_name)
        if not g1.has_node(node_name):
            g1.add_node(node_name, **attr)
            continue
        node = g1.nodes[node_name]
        node["description"] += GRAPH_FIELD_SEP + attr["description"]
        # A node's source_id indicates which chunks it came from.
        node["source_id"] += attr["source_id"]

    for source, target, attr in g2.edges(data=True):
        change.added_updated_edges.add(get_from_to(source, target))
        edge = g1.get_edge_data(source, target)
        if edge is None:
            g1.add_edge(source, target, **attr)
            continue
        edge["weight"] += attr.get("weight", 0)
        edge["description"] += GRAPH_FIELD_SEP + attr["description"]
        edge["keywords"] += attr["keywords"]
        # An edge's source_id indicates which chunks it came from.
        edge["source_id"] += attr["source_id"]

    for node_degree in g1.degree:
        g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    # A graph's source_id indicates which documents it came from.
    if "source_id" not in g1.graph:
        g1.graph["source_id"] = []
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
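    """Return an MD5 hex digest of the stringified positional arguments."""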
    return md5(str(args).encode()).hexdigest()


def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
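    """Parse one '"entity"' extraction record into a node dict (name, type, description, source chunk),
    or return None if the record is malformed."""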
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the G
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name.upper(),
        entity_type=entity_type.upper(),
        description=entity_description,
        source_id=entity_source_id,
    )


def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
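    """Parse one '"relationship"' extraction record into an edge dict (src/tgt, weight, description,
    keywords, source chunk), or return None if the record is malformed."""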
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])

    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    pair = sorted([source.upper(), target.upper()])
    return dict(
        src_id=pair[0],
        tgt_id=pair[1],
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
        metadata={"created_at": time.time()},
    )


def pack_user_ass_to_openai_messages(*args: str):
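    """Interleave strings into an OpenAI-style message list, alternating "user" and "assistant" roles.

    >>> pack_user_ass_to_openai_messages("hi", "hello")
    [{'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': 'hello'}]
    """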
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]


def is_float_regex(value):
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def chunk_id(chunk):
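    """Derive a stable chunk id from the chunk content and its knowledge base id."""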
    return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
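    """Convert a graph entity into an index chunk, embedding the entity name (with per-model caching),
    and append it to `chunks`."""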
    chunk = {
        "id": get_uuid(),
        "important_kwd": [ent_name],
        "title_tks": rag_tokenizer.tokenize(ent_name),
        "entity_kwd": ent_name,
        "knowledge_graph_kwd": "entity",
        "entity_type_kwd": meta["entity_type"],
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "source_id": meta["source_id"],
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
    if ebd is None:
        ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
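    """Search the doc store for relation chunks whose endpoints match the given entity names.
    Returns the first parsed relation when size == 1, otherwise a list (empty when nothing matches)."""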
    ents = from_ent_name
    if isinstance(ents, str):
        ents = [from_ent_name]
    if isinstance(to_ent_name, str):
        to_ent_name = [to_ent_name]
    ents.extend(to_ent_name)
    ents = list(set(ents))
    conds = {
        "fields": ["content_with_weight"],
        "size": size,
        "from_entity_kwd": ents,
        "to_entity_kwd": ents,
        "knowledge_graph_kwd": ["relation"]
    }
    res = []
    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    for id in es_res.ids:
        try:
            if size == 1:
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue
    return res


async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
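    """Convert a graph relation into an index chunk, embedding "from->to: description" (with per-model caching),
    and append it to `chunks`."""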
    chunk = {
        "id": get_uuid(),
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": meta["source_id"],
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    txt = f"{from_ent_name}->{to_ent_name}"
    ebd = get_embed_cache(embd_mdl.llm_name, txt)
    if ebd is None:
        ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, txt, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
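    """Return True if the stored knowledge graph of this kb already lists doc_id among its source documents."""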
    # Get doc_ids of graph
    fields = ["source_id"]
    condition = {
        "knowledge_graph_kwd": ["graph"],
        "removed_kwd": "N",
    }
    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
    fields2 = settings.docStoreConn.getFields(res, fields)
    graph_doc_ids = set()
    for chunk_id in fields2.keys():
        # Accumulate source ids across all returned chunks instead of keeping only the last one.
        graph_doc_ids.update(fields2[chunk_id]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
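    """Return the document ids the stored knowledge graph of this kb was built from (empty if no graph exists)."""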
    conds = {
        "fields": ["source_id"],
        "removed_kwd": "N",
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
    doc_ids = []
    if res.total == 0:
        return doc_ids
    for id in res.ids:
        doc_ids = res.field[id]["source_id"]
    return doc_ids


async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
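    """Load the knowledge graph of a kb from the doc store. A graph marked as removed is rebuilt from its
    per-document subgraphs; returns None when nothing usable is found."""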
    conds = {
        "fields": ["content_with_weight", "removed_kwd", "source_id"],
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
    if res.total > 0:
        for id in res.ids:
            try:
                if res.field[id]["removed_kwd"] == "N":
                    g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
                    if "source_id" not in g.graph:
                        g.graph["source_id"] = res.field[id]["source_id"]
                else:
                    g = await rebuild_graph(tenant_id, kb_id, exclude_rebuild)
                return g
            except Exception:
                continue
    return None


async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
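    """Persist the knowledge graph and its pending GraphChange: drop chunks of removed nodes/edges, rewrite the
    whole-graph and per-document subgraph chunks, and index embedded chunks for added/updated nodes and edges."""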
    start = trio.current_time()

    await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id))

    if change.removed_nodes:
        await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id))

    if change.removed_edges:
        async with trio.open_nursery() as nursery:
            for from_node, to_node in change.removed_edges:
                nursery.start_soon(lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
    start = now

    chunks = [{
        "id": get_uuid(),
        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
        "knowledge_graph_kwd": "graph",
        "kb_id": kb_id,
        "source_id": graph.graph.get("source_id", []),
        "available_int": 0,
        "removed_kwd": "N"
    }]

    # generate updated subgraphs
    for source in graph.graph["source_id"]:
        subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy()
        subgraph.graph["source_id"] = [source]
        for n in subgraph.nodes:
            subgraph.nodes[n]["source_id"] = [source]
        chunks.append({
            "id": get_uuid(),
            "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
            "knowledge_graph_kwd": "subgraph",
            "kb_id": kb_id,
            "source_id": [source],
            "available_int": 0,
            "removed_kwd": "N"
        })

    async with trio.open_nursery() as nursery:
        for node in change.added_updated_nodes:
            node_attrs = graph.nodes[node]
            nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
        for from_node, to_node in change.added_updated_edges:
            edge_attrs = graph.get_edge_data(from_node, to_node)
            if not edge_attrs:
                # added_updated_edges could record a non-existing edge if both from_node and to_node participate in node merging.
                continue
            nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
    start = now

    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            raise Exception(error_message)
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges in the index in {now - start:.2f}s.")


def is_continuous_subsequence(subseq, seq):
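    """Return True if the first element of subseq is immediately followed by its last element somewhere in seq.

    >>> is_continuous_subsequence(("b", "c"), ("a", "b", "c"))
    True
    >>> is_continuous_subsequence(("a", "c"), ("a", "b", "c"))
    False
    """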
    def find_all_indexes(tup, value):
        indexes = []
        start = 0
        while True:
            try:
                index = tup.index(value, start)
                indexes.append(index)
                start = index + 1
            except ValueError:
                break
        return indexes

    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1:
            if seq[idx + 1] == subseq[-1]:
                return True
    return False


def merge_tuples(list1, list2):
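    """Extend each tuple in list1 with tuples from list2 whose first element matches its last element,
    skipping matches whose pair already appears consecutively in the tuple (in either direction);
    tuples that end in a repeated element or cannot be extended are kept as-is."""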
    result = []
    for tup in list1:
        last_element = tup[-1]
        if last_element in tup[:-1]:
            result.append(tup)
        else:
            matching_tuples = [t for t in list2 if t[0] == last_element]
            already_match_flag = 0
            for match in matching_tuples:
                matchh = (match[1], match[0])
                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
                    continue
                already_match_flag = 1
                merged_tuple = tup + match[1:]
                result.append(merged_tuple)
            if not already_match_flag:
                result.append(tup)
    return result


async def get_entity_type2sampels(idxnms, kb_ids: list):
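    """Collect the "ty2ents" chunks of the given knowledge bases and merge them into one
    entity-type -> entity-samples mapping."""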
    es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
                                       "size": 10000,
                                       "fields": ["content_with_weight"]},
                                      idxnms, kb_ids))

    res = defaultdict(list)
    for id in es_res.ids:
        smp = es_res.field[id].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            # Skip malformed JSON instead of crashing on the .items() call below.
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res


def flat_uniq_list(arr, key):
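    """Collect the values of `key` from every dict in arr, flatten list values, and return them de-duplicated."""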
    res = []
    for a in arr:
        a = a[key]
        if isinstance(a, list):
            res.extend(a)
        else:
            res.append(a)
    return list(set(res))


async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
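    """Rebuild the knowledge graph by composing the stored per-document "subgraph" chunks, skipping subgraphs
    whose source documents appear in exclude_rebuild. Returns None if no subgraphs remain."""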
    graph = nx.Graph()
    flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
    bs = 256
    for i in range(0, 1024*bs, bs):
        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
                                 {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
                                 [],
                                 OrderByExpr(),
                                 i, bs, search.index_name(tenant_id), [kb_id]
                                 ))
        # tot = settings.docStoreConn.getTotal(es_res)
        es_res = settings.docStoreConn.getFields(es_res, flds)

        if len(es_res) == 0:
            break

        for id, d in es_res.items():
            assert d["knowledge_graph_kwd"] == "subgraph"
            if isinstance(exclude_rebuild, list):
                if any(n in d["source_id"] for n in exclude_rebuild):
                    continue
            elif exclude_rebuild in d["source_id"]:
                continue

            next_graph = json_graph.node_link_graph(json.loads(d["content_with_weight"]), edges="edges")
            merged_graph = nx.compose(graph, next_graph)
            merged_source = {
                n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"]
                for n in graph.nodes & next_graph.nodes
            }
            nx.set_node_attributes(merged_graph, merged_source, "source_id")
            if "source_id" in graph.graph:
                merged_graph.graph["source_id"] = graph.graph["source_id"] + next_graph.graph["source_id"]
            else:
                merged_graph.graph["source_id"] = next_graph.graph["source_id"]
            graph = merged_graph

    if len(graph.nodes) == 0:
        return None
    graph.graph["source_id"] = sorted(graph.graph["source_id"])
    return graph
 
 