import torch_geometric class GraphSAINTSamplerFactory: """ A simple factory class for creating varieties of :class:`torch_geometric.data.GraphSAINTSampler`. There exists potential sampling performance issues for the implementation of :class:`torch_geometric.data.GraphSAINTEdgeSampler` provided by PyTorch Geometric. Considering that the ultimate performance of GraphSAINT Edge Sampler and GraphSAINT Random Walk Sampler are similar according to the original literature `"GraphSAINT: Graph Sampling Based Inductive Learning Method" `_ which introduces the GraphSAINT approach, nevertheless, when the walk length for GraphSAINT Random Walk Sampler is specified as `2`, the Random walk operation is actually selecting edges. Therefore an effective implementation for GraphSAINT Edge Sampler is not very urgently needed. Meanwhile, the varieties of Subgraph-wise sampling is scheduled to be redesigned and refactored. With the aim of abstracting a unified sampling module for representative mainstream varieties of Node-wise Sampling, Layer-wise Sampling, and Subgraph-wise Sampling. """ @classmethod def create_node_sampler( cls, data, num_graphs_per_epoch: int, node_budget: int, sample_coverage_factor: int = 50, **kwargs ) -> torch_geometric.data.GraphSAINTNodeSampler: """ A simple static method for instantiating :class:`torch_geometric.data.GraphSAINTNodeSampler` with more explicit arguments. Arguments ------------ data: The conventional data of integral graph for sampling. num_graphs_per_epoch: number of subgraphs to sampler per epoch. node_budget: budget of nodes to sample for one sampled subgraph. sample_coverage_factor: The average number of samples per node should be used to compute normalization statistics. **kwargs: Additional optional arguments of :class:`torch.utils.data.DataLoader`, including :obj:`batch_size` or :obj:`num_workers`. Returns -------- Instance of :class:`torch_geometric.data.GraphSAINTNodeSampler`. """ return torch_geometric.data.GraphSAINTNodeSampler( data, node_budget, num_graphs_per_epoch, sample_coverage_factor, log=False, **kwargs ) @classmethod def create_edge_sampler( cls, data, num_graphs_per_epoch: int, edge_budget: int, sample_coverage_factor: int = 50, **kwargs ) -> torch_geometric.data.GraphSAINTEdgeSampler: """ A simple static method for instantiating :class:`torch_geometric.data.GraphSAINTEdgeSampler` with more explicit arguments. Arguments ------------ data: The conventional data of integral graph for sampling. num_graphs_per_epoch: number of subgraphs to sampler per epoch. edge_budget: budget of edges to sample for one sampled subgraph. sample_coverage_factor: The average number of samples per node should be used to compute normalization statistics. **kwargs: Additional optional arguments of :class:`torch.utils.data.DataLoader`, including :obj:`batch_size` or :obj:`num_workers`. Returns -------- Instance of :class:`torch_geometric.data.GraphSAINTEdgeSampler`. """ return torch_geometric.data.GraphSAINTEdgeSampler( data, edge_budget, num_graphs_per_epoch, sample_coverage_factor, log=False, **kwargs ) @classmethod def create_random_walk_sampler( cls, data, num_graphs_per_epoch: int, num_walks: int, walk_length: int, sample_coverage_factor: int = 50, **kwargs ) -> torch_geometric.data.GraphSAINTRandomWalkSampler: """ A simple static method for instantiating :class:`torch_geometric.data.GraphSAINTEdgeSampler` with more explicit arguments. Arguments ------------ data: The conventional data of integral graph for sampling. num_graphs_per_epoch: number of subgraphs to sampler per epoch. num_walks: The number of random walks for sampling. walk_length: The length of each random walk. sample_coverage_factor: The average number of samples per node should be used to compute normalization statistics. **kwargs: Additional optional arguments of :class:`torch.utils.data.DataLoader`, including :obj:`batch_size` or :obj:`num_workers`. Returns -------- Instance of :class:`torch_geometric.data.GraphSAINTEdgeSampler`. """ return torch_geometric.data.GraphSAINTRandomWalkSampler( data, num_walks, walk_length, num_graphs_per_epoch, sample_coverage_factor, log=False, **kwargs )