You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

entity_embedding.py 2.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. """
  17. Reference:
  18. - [graphrag](https://github.com/microsoft/graphrag)
  19. """
from dataclasses import dataclass
from typing import Any

import networkx as nx
import numpy as np

from graphrag.leiden import stable_largest_connected_component
  24. @dataclass
  25. class NodeEmbeddings:
  26. """Node embeddings class definition."""
  27. nodes: list[str]
  28. embeddings: np.ndarray
  29. def embed_nod2vec(
  30. graph: nx.Graph | nx.DiGraph,
  31. dimensions: int = 1536,
  32. num_walks: int = 10,
  33. walk_length: int = 40,
  34. window_size: int = 2,
  35. iterations: int = 3,
  36. random_seed: int = 86,
  37. ) -> NodeEmbeddings:
  38. """Generate node embeddings using Node2Vec."""
  39. # generate embedding
  40. lcc_tensors = gc.embed.node2vec_embed( # type: ignore
  41. graph=graph,
  42. dimensions=dimensions,
  43. window_size=window_size,
  44. iterations=iterations,
  45. num_walks=num_walks,
  46. walk_length=walk_length,
  47. random_seed=random_seed,
  48. )
  49. return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
  50. def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
  51. """Run method definition."""
  52. if args.get("use_lcc", True):
  53. graph = stable_largest_connected_component(graph)
  54. # create graph embedding using node2vec
  55. embeddings = embed_nod2vec(
  56. graph=graph,
  57. dimensions=args.get("dimensions", 1536),
  58. num_walks=args.get("num_walks", 10),
  59. walk_length=args.get("walk_length", 40),
  60. window_size=args.get("window_size", 2),
  61. iterations=args.get("iterations", 3),
  62. random_seed=args.get("random_seed", 86),
  63. )
  64. pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
  65. sorted_pairs = sorted(pairs, key=lambda x: x[0])
  66. return dict(sorted_pairs)