選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

entity_embedding.py 1.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
from dataclasses import dataclass
from typing import Any

import networkx as nx
import numpy as np

from graphrag.leiden import stable_largest_connected_component
  11. @dataclass
  12. class NodeEmbeddings:
  13. """Node embeddings class definition."""
  14. nodes: list[str]
  15. embeddings: np.ndarray
  16. def embed_nod2vec(
  17. graph: nx.Graph | nx.DiGraph,
  18. dimensions: int = 1536,
  19. num_walks: int = 10,
  20. walk_length: int = 40,
  21. window_size: int = 2,
  22. iterations: int = 3,
  23. random_seed: int = 86,
  24. ) -> NodeEmbeddings:
  25. """Generate node embeddings using Node2Vec."""
  26. # generate embedding
  27. lcc_tensors = gc.embed.node2vec_embed( # type: ignore
  28. graph=graph,
  29. dimensions=dimensions,
  30. window_size=window_size,
  31. iterations=iterations,
  32. num_walks=num_walks,
  33. walk_length=walk_length,
  34. random_seed=random_seed,
  35. )
  36. return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
  37. def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
  38. """Run method definition."""
  39. if args.get("use_lcc", True):
  40. graph = stable_largest_connected_component(graph)
  41. # create graph embedding using node2vec
  42. embeddings = embed_nod2vec(
  43. graph=graph,
  44. dimensions=args.get("dimensions", 1536),
  45. num_walks=args.get("num_walks", 10),
  46. walk_length=args.get("walk_length", 40),
  47. window_size=args.get("window_size", 2),
  48. iterations=args.get("iterations", 3),
  49. random_seed=args.get("random_seed", 86),
  50. )
  51. pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
  52. sorted_pairs = sorted(pairs, key=lambda x: x[0])
  53. return dict(sorted_pairs)