| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 | 
							- # Copyright (c) 2024 Microsoft Corporation.
 - # Licensed under the MIT License
 - """
 - Reference:
 -  - [graphrag](https://github.com/microsoft/graphrag)
 - """
 - 
 - from typing import Any
 - import numpy as np
 - import networkx as nx
 - from dataclasses import dataclass
 - from graphrag.leiden import stable_largest_connected_component
 - import graspologic as gc
 - 
 - 
 - @dataclass
 - class NodeEmbeddings:
 -     """Node embeddings class definition."""
 - 
 -     nodes: list[str]
 -     embeddings: np.ndarray
 - 
 - 
 - def embed_nod2vec(
 -     graph: nx.Graph | nx.DiGraph,
 -     dimensions: int = 1536,
 -     num_walks: int = 10,
 -     walk_length: int = 40,
 -     window_size: int = 2,
 -     iterations: int = 3,
 -     random_seed: int = 86,
 - ) -> NodeEmbeddings:
 -     """Generate node embeddings using Node2Vec."""
 -     # generate embedding
 -     lcc_tensors = gc.embed.node2vec_embed(  # type: ignore
 -         graph=graph,
 -         dimensions=dimensions,
 -         window_size=window_size,
 -         iterations=iterations,
 -         num_walks=num_walks,
 -         walk_length=walk_length,
 -         random_seed=random_seed,
 -     )
 -     return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
 - 
 - 
 - def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
 -     """Run method definition."""
 -     if args.get("use_lcc", True):
 -         graph = stable_largest_connected_component(graph)
 - 
 -     # create graph embedding using node2vec
 -     embeddings = embed_nod2vec(
 -         graph=graph,
 -         dimensions=args.get("dimensions", 1536),
 -         num_walks=args.get("num_walks", 10),
 -         walk_length=args.get("walk_length", 40),
 -         window_size=args.get("window_size", 2),
 -         iterations=args.get("iterations", 3),
 -         random_seed=args.get("random_seed", 86),
 -     )
 - 
 -     pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
 -     sorted_pairs = sorted(pairs, key=lambda x: x[0])
 - 
 -     return dict(sorted_pairs)
 
 
  |