您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

leiden.py 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. """
  17. Reference:
  18. - [graphrag](https://github.com/microsoft/graphrag)
  19. """
  20. import logging
  21. from typing import Any, cast, List
  22. import html
  23. from graspologic.partition import hierarchical_leiden
  24. from graspologic.utils import largest_connected_component
  25. import networkx as nx
  26. from networkx import is_empty
  27. log = logging.getLogger(__name__)
  28. def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
  29. """Ensure an undirected graph with the same relationships will always be read the same way."""
  30. fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
  31. sorted_nodes = graph.nodes(data=True)
  32. sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
  33. fixed_graph.add_nodes_from(sorted_nodes)
  34. edges = list(graph.edges(data=True))
  35. # If the graph is undirected, we create the edges in a stable way, so we get the same results
  36. # for example:
  37. # A -> B
  38. # in graph theory is the same as
  39. # B -> A
  40. # in an undirected graph
  41. # however, this can lead to downstream issues because sometimes
  42. # consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A]
  43. # but they base some of their logic on the order of the nodes, so the order ends up being important
  44. # so we sort the nodes in the edge in a stable way, so that we always get the same order
  45. if not graph.is_directed():
  46. def _sort_source_target(edge):
  47. source, target, edge_data = edge
  48. if source > target:
  49. temp = source
  50. source = target
  51. target = temp
  52. return source, target, edge_data
  53. edges = [_sort_source_target(edge) for edge in edges]
  54. def _get_edge_key(source: Any, target: Any) -> str:
  55. return f"{source} -> {target}"
  56. edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
  57. fixed_graph.add_edges_from(edges)
  58. return fixed_graph
  59. def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
  60. """Normalize node names."""
  61. node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
  62. return nx.relabel_nodes(graph, node_mapping)
  63. def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
  64. """Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
  65. graph = graph.copy()
  66. graph = cast(nx.Graph, largest_connected_component(graph))
  67. graph = normalize_node_names(graph)
  68. return _stabilize_graph(graph)
  69. def _compute_leiden_communities(
  70. graph: nx.Graph | nx.DiGraph,
  71. max_cluster_size: int,
  72. use_lcc: bool,
  73. seed=0xDEADBEEF,
  74. ) -> dict[int, dict[str, int]]:
  75. """Return Leiden root communities."""
  76. results: dict[int, dict[str, int]] = {}
  77. if is_empty(graph): return results
  78. if use_lcc:
  79. graph = stable_largest_connected_component(graph)
  80. community_mapping = hierarchical_leiden(
  81. graph, max_cluster_size=max_cluster_size, random_seed=seed
  82. )
  83. for partition in community_mapping:
  84. results[partition.level] = results.get(partition.level, {})
  85. results[partition.level][partition.node] = partition.cluster
  86. return results
  87. def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
  88. """Run method definition."""
  89. max_cluster_size = args.get("max_cluster_size", 12)
  90. use_lcc = args.get("use_lcc", True)
  91. if args.get("verbose", False):
  92. log.info(
  93. "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
  94. )
  95. if not graph.nodes(): return {}
  96. node_id_to_community_map = _compute_leiden_communities(
  97. graph=graph,
  98. max_cluster_size=max_cluster_size,
  99. use_lcc=use_lcc,
  100. seed=args.get("seed", 0xDEADBEEF),
  101. )
  102. levels = args.get("levels")
  103. # If they don't pass in levels, use them all
  104. if levels is None:
  105. levels = sorted(node_id_to_community_map.keys())
  106. results_by_level: dict[int, dict[str, list[str]]] = {}
  107. for level in levels:
  108. result = {}
  109. results_by_level[level] = result
  110. for node_id, raw_community_id in node_id_to_community_map[level].items():
  111. community_id = str(raw_community_id)
  112. if community_id not in result:
  113. result[community_id] = {"weight": 0, "nodes": []}
  114. result[community_id]["nodes"].append(node_id)
  115. result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
  116. weights = [comm["weight"] for _, comm in result.items()]
  117. if not weights:continue
  118. max_weight = max(weights)
  119. for _, comm in result.items(): comm["weight"] /= max_weight
  120. return results_by_level
  121. def add_community_info2graph(graph: nx.Graph, commu_info: dict[str, dict[str, dict]]):
  122. for lev, cluster_info in commu_info.items():
  123. for cid, nodes in cluster_info.items():
  124. for n in nodes["nodes"]:
  125. if "community" not in graph.nodes[n]: graph.nodes[n]["community"] = {}
  126. graph.nodes[n]["community"].update({lev: cid})
  127. def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
  128. for n in nodes:
  129. if "communities" not in graph.nodes[n]:
  130. graph.nodes[n]["communities"] = []
  131. graph.nodes[n]["communities"].append(community_title)