You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. from typing import Any
  8. import numpy as np
  9. import networkx as nx
  10. from dataclasses import dataclass
  11. from graphrag.leiden import stable_largest_connected_component
  12. import graspologic as gc
  @dataclass
  class NodeEmbeddings:
      """Container pairing node identifiers with their embedding array."""

      # Node identifiers, in the order supplied by the embedding routine.
      nodes: list[str]
      # Embedding values for the nodes.
      # NOTE(review): presumably one row per entry of `nodes` -- confirm
      # against the embedding backend.
      embeddings: np.ndarray
  def embed_nod2vec(
      graph: nx.Graph | nx.DiGraph,
      dimensions: int = 1536,
      num_walks: int = 10,
      walk_length: int = 40,
      window_size: int = 2,
      iterations: int = 3,
      random_seed: int = 86,
  ) -> NodeEmbeddings:
      """Generate node embeddings using Node2Vec.

      Every argument is forwarded unchanged, by keyword, to
      ``graspologic.embed.node2vec_embed``.

      Args:
          graph: Graph whose nodes are embedded.
          dimensions: Forwarded as ``dimensions``.
          num_walks: Forwarded as ``num_walks``.
          walk_length: Forwarded as ``walk_length``.
          window_size: Forwarded as ``window_size``.
          iterations: Forwarded as ``iterations``.
          random_seed: Forwarded as ``random_seed``.

      Returns:
          A ``NodeEmbeddings`` whose ``embeddings`` is element 0 and whose
          ``nodes`` is element 1 of the value returned by ``node2vec_embed``.
      """
      # generate embedding
      result = gc.embed.node2vec_embed(  # type: ignore
          graph=graph,
          dimensions=dimensions,
          window_size=window_size,
          iterations=iterations,
          num_walks=num_walks,
          walk_length=walk_length,
          random_seed=random_seed,
      )
      # NOTE(review): the result is presumably an (embedding matrix, node
      # labels) pair -- confirm against the graspologic documentation.
      return NodeEmbeddings(embeddings=result[0], nodes=result[1])
  def run(graph: nx.Graph, args: dict[str, Any]) -> dict[str, list[float]]:
      """Embed the nodes of ``graph`` and return a node -> vector mapping.

      Args:
          graph: Graph to embed.
          args: Options read with ``dict.get``. Recognised keys and their
              fallbacks: ``use_lcc`` (True), ``dimensions`` (1536),
              ``num_walks`` (10), ``walk_length`` (40), ``window_size`` (2),
              ``iterations`` (3), ``random_seed`` (86).

      Returns:
          A dict built from (node, embedding-as-list) pairs sorted by node, so
          insertion order follows ascending node order.

      Raises:
          ValueError: If the number of nodes and the number of embedding rows
              differ (``zip(..., strict=True)``).
      """
      if args.get("use_lcc", True):
          # Embed only what stable_largest_connected_component returns.
          graph = stable_largest_connected_component(graph)

      # create graph embedding using node2vec
      embeddings = embed_nod2vec(
          graph=graph,
          dimensions=args.get("dimensions", 1536),
          num_walks=args.get("num_walks", 10),
          walk_length=args.get("walk_length", 40),
          window_size=args.get("window_size", 2),
          iterations=args.get("iterations", 3),
          random_seed=args.get("random_seed", 86),
      )

      # tolist() converts the array to plain Python lists; strict=True rejects
      # a length mismatch between nodes and embedding rows.
      pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
      # Sort by node so the resulting dict has a deterministic key order.
      sorted_pairs = sorted(pairs, key=lambda pair: pair[0])
      return dict(sorted_pairs)