You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

entity_embedding.py 1.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. from typing import Any
  8. import numpy as np
  9. import networkx as nx
  10. from dataclasses import dataclass
  11. from graphrag.leiden import stable_largest_connected_component
  12. @dataclass
  13. class NodeEmbeddings:
  14. """Node embeddings class definition."""
  15. nodes: list[str]
  16. embeddings: np.ndarray
  17. def embed_nod2vec(
  18. graph: nx.Graph | nx.DiGraph,
  19. dimensions: int = 1536,
  20. num_walks: int = 10,
  21. walk_length: int = 40,
  22. window_size: int = 2,
  23. iterations: int = 3,
  24. random_seed: int = 86,
  25. ) -> NodeEmbeddings:
  26. """Generate node embeddings using Node2Vec."""
  27. # generate embedding
  28. lcc_tensors = gc.embed.node2vec_embed( # type: ignore
  29. graph=graph,
  30. dimensions=dimensions,
  31. window_size=window_size,
  32. iterations=iterations,
  33. num_walks=num_walks,
  34. walk_length=walk_length,
  35. random_seed=random_seed,
  36. )
  37. return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
  38. def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
  39. """Run method definition."""
  40. if args.get("use_lcc", True):
  41. graph = stable_largest_connected_component(graph)
  42. # create graph embedding using node2vec
  43. embeddings = embed_nod2vec(
  44. graph=graph,
  45. dimensions=args.get("dimensions", 1536),
  46. num_walks=args.get("num_walks", 10),
  47. walk_length=args.get("walk_length", 40),
  48. window_size=args.get("window_size", 2),
  49. iterations=args.get("iterations", 3),
  50. random_seed=args.get("random_seed", 86),
  51. )
  52. pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
  53. sorted_pairs = sorted(pairs, key=lambda x: x[0])
  54. return dict(sorted_pairs)