You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_graphdata_distributed.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import os
  16. import random
  17. import time
  18. from multiprocessing import Process
  19. import numpy as np
  20. import mindspore.dataset as ds
  21. from mindspore import log as logger
  22. from mindspore.dataset.engine import SamplingStrategy
  23. DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
  24. def graphdata_startserver(server_port):
  25. """
  26. start graphdata server
  27. """
  28. logger.info('test start server.\n')
  29. ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
  30. class RandomBatchedSampler(ds.Sampler):
  31. # RandomBatchedSampler generate random sequence without replacement in a batched manner
  32. def __init__(self, index_range, num_edges_per_sample):
  33. super().__init__()
  34. self.index_range = index_range
  35. self.num_edges_per_sample = num_edges_per_sample
  36. def __iter__(self):
  37. indices = [i+1 for i in range(self.index_range)]
  38. # Reset random seed here if necessary
  39. # random.seed(0)
  40. random.shuffle(indices)
  41. for i in range(0, self.index_range, self.num_edges_per_sample):
  42. # Drop reminder
  43. if i + self.num_edges_per_sample <= self.index_range:
  44. yield indices[i: i + self.num_edges_per_sample]
  45. class GNNGraphDataset():
  46. def __init__(self, g, batch_num):
  47. self.g = g
  48. self.batch_num = batch_num
  49. def __len__(self):
  50. # Total sample size of GNN dataset
  51. # In this case, the size should be total_num_edges/num_edges_per_sample
  52. return self.g.graph_info()['edge_num'][0] // self.batch_num
  53. def __getitem__(self, index):
  54. # index will be a list of indices yielded from RandomBatchedSampler
  55. # Fetch edges/nodes/samples/features based on indices
  56. nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
  57. nodes = nodes[:, 0]
  58. neg_nodes = self.g.get_neg_sampled_neighbors(
  59. node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
  60. nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
  61. 2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM)
  62. neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
  63. neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
  64. nodes_neighbors_features = self.g.get_node_feature(
  65. node_list=nodes_neighbors, feature_types=[2, 3])
  66. neg_neighbors_features = self.g.get_node_feature(
  67. node_list=neg_nodes_neighbors, feature_types=[2, 3])
  68. return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
  69. def test_graphdata_distributed():
  70. """
  71. Test distributed
  72. """
  73. ASAN = os.environ.get('ASAN_OPTIONS')
  74. if ASAN:
  75. logger.info("skip the graphdata distributed when asan mode")
  76. return
  77. logger.info('test distributed.\n')
  78. server_port = random.randint(10000, 60000)
  79. p1 = Process(target=graphdata_startserver, args=(server_port,))
  80. p1.start()
  81. time.sleep(5)
  82. g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
  83. nodes = g.get_all_nodes(1)
  84. assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
  85. row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
  86. assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
  87. [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
  88. [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
  89. assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]
  90. edges = g.get_all_edges(0)
  91. assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
  92. 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
  93. features = g.get_edge_feature(edges, [1, 2])
  94. assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
  95. 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
  96. nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)]
  97. edges = g.get_edges_from_nodes(nodes_pair_list)
  98. assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37]
  99. batch_num = 2
  100. edge_num = g.graph_info()['edge_num'][0]
  101. out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
  102. dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
  103. sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
  104. python_multiprocessing=False)
  105. dataset = dataset.repeat(2)
  106. itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
  107. i = 0
  108. for data in itr:
  109. assert data['neighbors'].shape == (2, 7)
  110. assert data['neg_neighbors'].shape == (6, 7)
  111. assert data['neighbors_features'].shape == (2, 7)
  112. assert data['neg_neighbors_features'].shape == (6, 7)
  113. i += 1
  114. assert i == 40
  115. if __name__ == '__main__':
  116. test_graphdata_distributed()