graphsaint_sampler.py

import torch_geometric


class GraphSAINTSamplerFactory:
    """
    A simple factory class for creating the variants of
    :class:`torch_geometric.data.GraphSAINTSampler`.

    The :class:`torch_geometric.data.GraphSAINTEdgeSampler` implementation
    provided by PyTorch Geometric may suffer from sampling performance issues.
    However, according to the original paper
    `"GraphSAINT: Graph Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_, which introduces the GraphSAINT approach,
    the GraphSAINT Edge Sampler and the GraphSAINT Random Walk Sampler achieve
    similar final performance. Moreover, when the walk length of the GraphSAINT
    Random Walk Sampler is set to `2`, the random walk operation effectively
    selects edges. An efficient implementation of the GraphSAINT Edge Sampler is
    therefore not urgently needed.

    Meanwhile, the subgraph-wise samplers are scheduled to be redesigned and
    refactored, with the aim of abstracting a unified sampling module for the
    representative mainstream node-wise, layer-wise, and subgraph-wise
    sampling methods.
    """

    @classmethod
    def create_node_sampler(
        cls,
        data,
        num_graphs_per_epoch: int,
        node_budget: int,
        sample_coverage_factor: int = 50,
        **kwargs
    ) -> torch_geometric.data.GraphSAINTNodeSampler:
        """
        Instantiate :class:`torch_geometric.data.GraphSAINTNodeSampler`
        with more explicit argument names.

        Arguments
        ---------
        data:
            The data object of the entire graph to sample from.
        num_graphs_per_epoch:
            Number of subgraphs to sample per epoch.
        node_budget:
            Number of nodes to sample for each subgraph.
        sample_coverage_factor:
            How many samples per node should be used on average to
            compute the normalization statistics.
        **kwargs:
            Additional optional arguments of :class:`torch.utils.data.DataLoader`,
            such as :obj:`num_workers`. Do not pass :obj:`batch_size`; it is
            already bound to :obj:`node_budget`.

        Returns
        -------
        Instance of :class:`torch_geometric.data.GraphSAINTNodeSampler`.
        """
        return torch_geometric.data.GraphSAINTNodeSampler(
            data,
            batch_size=node_budget,  # PyG names the node budget `batch_size`
            num_steps=num_graphs_per_epoch,
            sample_coverage=sample_coverage_factor,
            log=False,  # do not log pre-processing progress
            **kwargs
        )

    @classmethod
    def create_edge_sampler(
        cls,
        data,
        num_graphs_per_epoch: int,
        edge_budget: int,
        sample_coverage_factor: int = 50,
        **kwargs
    ) -> torch_geometric.data.GraphSAINTEdgeSampler:
        """
        Instantiate :class:`torch_geometric.data.GraphSAINTEdgeSampler`
        with more explicit argument names.

        Arguments
        ---------
        data:
            The data object of the entire graph to sample from.
        num_graphs_per_epoch:
            Number of subgraphs to sample per epoch.
        edge_budget:
            Number of edges to sample for each subgraph.
        sample_coverage_factor:
            How many samples per node should be used on average to
            compute the normalization statistics.
        **kwargs:
            Additional optional arguments of :class:`torch.utils.data.DataLoader`,
            such as :obj:`num_workers`. Do not pass :obj:`batch_size`; it is
            already bound to :obj:`edge_budget`.

        Returns
        -------
        Instance of :class:`torch_geometric.data.GraphSAINTEdgeSampler`.
        """
        return torch_geometric.data.GraphSAINTEdgeSampler(
            data,
            batch_size=edge_budget,  # PyG names the edge budget `batch_size`
            num_steps=num_graphs_per_epoch,
            sample_coverage=sample_coverage_factor,
            log=False,  # do not log pre-processing progress
            **kwargs
        )

    @classmethod
    def create_random_walk_sampler(
        cls,
        data,
        num_graphs_per_epoch: int,
        num_walks: int,
        walk_length: int,
        sample_coverage_factor: int = 50,
        **kwargs
    ) -> torch_geometric.data.GraphSAINTRandomWalkSampler:
        """
        Instantiate :class:`torch_geometric.data.GraphSAINTRandomWalkSampler`
        with more explicit argument names.

        Arguments
        ---------
        data:
            The data object of the entire graph to sample from.
        num_graphs_per_epoch:
            Number of subgraphs to sample per epoch.
        num_walks:
            Number of random walks (i.e., root nodes) for each subgraph.
        walk_length:
            The length of each random walk.
        sample_coverage_factor:
            How many samples per node should be used on average to
            compute the normalization statistics.
        **kwargs:
            Additional optional arguments of :class:`torch.utils.data.DataLoader`,
            such as :obj:`num_workers`. Do not pass :obj:`batch_size`; it is
            already bound to :obj:`num_walks`.

        Returns
        -------
        Instance of :class:`torch_geometric.data.GraphSAINTRandomWalkSampler`.
        """
        return torch_geometric.data.GraphSAINTRandomWalkSampler(
            data,
            batch_size=num_walks,  # PyG names the number of walk roots `batch_size`
            walk_length=walk_length,
            num_steps=num_graphs_per_epoch,
            sample_coverage=sample_coverage_factor,
            log=False,  # do not log pre-processing progress
            **kwargs
        )
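
The class docstring argues that a short random walk effectively selects edges. The pure-Python sketch below illustrates that correspondence without PyTorch Geometric; the helpers `random_walk` and `sample_subgraph` and the toy graph are hypothetical. It assumes uniformly drawn roots and a walk length counted in steps (our reading of PyG's `GraphSAINTRandomWalkSampler`), so a one-step walk visits two nodes, i.e. a length of 2 when counted in visited nodes, and traverses exactly one edge.

```python
import random
from typing import Dict, List, Set, Tuple

# Hypothetical toy representation: undirected graph as a symmetric adjacency list.
Graph = Dict[int, List[int]]


def random_walk(adj: Graph, root: int, walk_length: int, rng: random.Random) -> List[int]:
    """Uniform random walk of `walk_length` steps; visits up to walk_length + 1 nodes."""
    path = [root]
    node = root
    for _ in range(walk_length):
        neighbors = adj[node]
        if not neighbors:  # dead end (isolated node): stop the walk early
            break
        node = rng.choice(neighbors)
        path.append(node)
    return path


def sample_subgraph(
    adj: Graph, num_walks: int, walk_length: int, seed: int = 0
) -> Tuple[Set[int], Set[Tuple[int, int]], Set[Tuple[int, int]]]:
    """Sample one GraphSAINT-style subgraph from `num_walks` uniformly chosen roots."""
    rng = random.Random(seed)
    roots = list(adj)
    sampled: Set[int] = set()
    walked: Set[Tuple[int, int]] = set()
    for _ in range(num_walks):
        path = random_walk(adj, rng.choice(roots), walk_length, rng)
        sampled.update(path)
        # Record each traversed edge in canonical (min, max) order.
        walked.update((min(u, v), max(u, v)) for u, v in zip(path, path[1:]))
    # GraphSAINT trains on the subgraph induced by all sampled nodes.
    induced = {(u, v) for u in sampled for v in adj[u] if v in sampled and u < v}
    return sampled, walked, induced


if __name__ == "__main__":
    toy = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
    # The toy graph has no isolated nodes, so each one-step walk traverses one edge.
    nodes, walked, induced = sample_subgraph(toy, num_walks=2, walk_length=1)
    print("nodes:", sorted(nodes))
    print("walked edges:", sorted(walked))
    print("induced edges:", sorted(induced))
```

The sketch only shows the structural correspondence: edges reached this way are not drawn with the GraphSAINT edge sampler's distribution, which weights an edge (u, v) by 1/deg(u) + 1/deg(v).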