You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_graphdata_distributed.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import os
  16. import random
  17. import time
  18. from multiprocessing import Process
  19. import numpy as np
  20. import mindspore.dataset as ds
  21. from mindspore import log as logger
  22. from mindspore.dataset.engine import SamplingStrategy
  23. from mindspore.dataset.engine import OutputFormat
  24. DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
  25. def graphdata_startserver(server_port):
  26. """
  27. start graphdata server
  28. """
  29. logger.info('test start server.\n')
  30. ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
  31. class RandomBatchedSampler(ds.Sampler):
  32. # RandomBatchedSampler generate random sequence without replacement in a batched manner
  33. def __init__(self, index_range, num_edges_per_sample):
  34. super().__init__()
  35. self.index_range = index_range
  36. self.num_edges_per_sample = num_edges_per_sample
  37. def __iter__(self):
  38. indices = [i+1 for i in range(self.index_range)]
  39. # Reset random seed here if necessary
  40. # random.seed(0)
  41. random.shuffle(indices)
  42. for i in range(0, self.index_range, self.num_edges_per_sample):
  43. # Drop reminder
  44. if i + self.num_edges_per_sample <= self.index_range:
  45. yield indices[i: i + self.num_edges_per_sample]
  46. class GNNGraphDataset():
  47. def __init__(self, g, batch_num):
  48. self.g = g
  49. self.batch_num = batch_num
  50. def __len__(self):
  51. # Total sample size of GNN dataset
  52. # In this case, the size should be total_num_edges/num_edges_per_sample
  53. return self.g.graph_info()['edge_num'][0] // self.batch_num
  54. def __getitem__(self, index):
  55. # index will be a list of indices yielded from RandomBatchedSampler
  56. # Fetch edges/nodes/samples/features based on indices
  57. nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
  58. nodes = nodes[:, 0]
  59. neg_nodes = self.g.get_neg_sampled_neighbors(
  60. node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
  61. nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
  62. 2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM)
  63. neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
  64. neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
  65. nodes_neighbors_features = self.g.get_node_feature(
  66. node_list=nodes_neighbors, feature_types=[2, 3])
  67. neg_neighbors_features = self.g.get_node_feature(
  68. node_list=neg_nodes_neighbors, feature_types=[2, 3])
  69. return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
  70. def test_graphdata_distributed():
  71. """
  72. Test distributed
  73. """
  74. ASAN = os.environ.get('ASAN_OPTIONS')
  75. if ASAN:
  76. logger.info("skip the graphdata distributed when asan mode")
  77. return
  78. logger.info('test distributed.\n')
  79. server_port = random.randint(10000, 60000)
  80. p1 = Process(target=graphdata_startserver, args=(server_port,))
  81. p1.start()
  82. time.sleep(5)
  83. g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
  84. nodes = g.get_all_nodes(1)
  85. assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
  86. row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
  87. assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
  88. [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
  89. [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
  90. assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]
  91. neighbor_normal = g.get_all_neighbors(nodes, 2, OutputFormat.NORMAL)
  92. assert neighbor_normal.shape == (10, 6)
  93. neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
  94. assert neighbor_coo.shape == (20, 2)
  95. offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
  96. assert offset_table.shape == (10,)
  97. assert neighbor_csr.shape == (20,)
  98. edges = g.get_all_edges(0)
  99. assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
  100. 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
  101. features = g.get_edge_feature(edges, [1, 2])
  102. assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
  103. 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
  104. nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)]
  105. edges = g.get_edges_from_nodes(nodes_pair_list)
  106. assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37]
  107. batch_num = 2
  108. edge_num = g.graph_info()['edge_num'][0]
  109. out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
  110. dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
  111. sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
  112. python_multiprocessing=False)
  113. dataset = dataset.repeat(2)
  114. itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
  115. i = 0
  116. for data in itr:
  117. assert data['neighbors'].shape == (2, 7)
  118. assert data['neg_neighbors'].shape == (6, 7)
  119. assert data['neighbors_features'].shape == (2, 7)
  120. assert data['neg_neighbors_features'].shape == (6, 7)
  121. i += 1
  122. assert i == 40
# Allow running this test file directly, outside the pytest runner.
if __name__ == '__main__':
    test_graphdata_distributed()