You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_graphdata_distributed.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import os
  16. import random
  17. import time
  18. from multiprocessing import Process
  19. import numpy as np
  20. import mindspore.dataset as ds
  21. from mindspore import log as logger
  22. DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
  23. def graphdata_startserver(server_port):
  24. """
  25. start graphdata server
  26. """
  27. logger.info('test start server.\n')
  28. ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
  29. class RandomBatchedSampler(ds.Sampler):
  30. # RandomBatchedSampler generate random sequence without replacement in a batched manner
  31. def __init__(self, index_range, num_edges_per_sample):
  32. super().__init__()
  33. self.index_range = index_range
  34. self.num_edges_per_sample = num_edges_per_sample
  35. def __iter__(self):
  36. indices = [i+1 for i in range(self.index_range)]
  37. # Reset random seed here if necessary
  38. # random.seed(0)
  39. random.shuffle(indices)
  40. for i in range(0, self.index_range, self.num_edges_per_sample):
  41. # Drop reminder
  42. if i + self.num_edges_per_sample <= self.index_range:
  43. yield indices[i: i + self.num_edges_per_sample]
  44. class GNNGraphDataset():
  45. def __init__(self, g, batch_num):
  46. self.g = g
  47. self.batch_num = batch_num
  48. def __len__(self):
  49. # Total sample size of GNN dataset
  50. # In this case, the size should be total_num_edges/num_edges_per_sample
  51. return self.g.graph_info()['edge_num'][0] // self.batch_num
  52. def __getitem__(self, index):
  53. # index will be a list of indices yielded from RandomBatchedSampler
  54. # Fetch edges/nodes/samples/features based on indices
  55. nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
  56. nodes = nodes[:, 0]
  57. neg_nodes = self.g.get_neg_sampled_neighbors(
  58. node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
  59. nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
  60. 2, 2], neighbor_types=[2, 1])
  61. neg_nodes_neighbors = self.g.get_sampled_neighbors(
  62. node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
  63. nodes_neighbors_features = self.g.get_node_feature(
  64. node_list=nodes_neighbors, feature_types=[2, 3])
  65. neg_neighbors_features = self.g.get_node_feature(
  66. node_list=neg_nodes_neighbors, feature_types=[2, 3])
  67. return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
def test_graphdata_distributed():
    """
    Test GraphData in distributed (server/client) mode: start a server in a
    child process, connect a client, validate node/edge/feature queries
    against the known test dataset, then run a GeneratorDataset pipeline
    built on top of the client.
    """
    # Skip entirely when running under AddressSanitizer.
    ASAN = os.environ.get('ASAN_OPTIONS')
    if ASAN:
        logger.info("skip the graphdata distributed when asan mode")
        return
    logger.info('test distributed.\n')
    # Random port to reduce collisions between concurrent test runs.
    server_port = random.randint(10000, 60000)
    # NOTE(review): the server process is never joined/terminated here —
    # presumably it exits on its own when the test ends; confirm.
    p1 = Process(target=graphdata_startserver, args=(server_port,))
    p1.start()
    # Give the server time to come up before the client connects.
    time.sleep(5)
    g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
    # Expected values below are fixed by the testGraphData dataset contents.
    nodes = g.get_all_nodes(1)
    assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
    row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
    assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
                                      [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
                                      [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
    assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]
    edges = g.get_all_edges(0)
    assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                              21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
    features = g.get_edge_feature(edges, [1, 2])
    assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
                                    0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
    # Build a GeneratorDataset over the client using the batched sampler.
    batch_num = 2
    edge_num = g.graph_info()['edge_num'][0]
    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                  sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
                                  python_multiprocessing=False)
    dataset = dataset.repeat(2)
    itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    i = 0
    for data in itr:
        assert data['neighbors'].shape == (2, 7)
        assert data['neg_neighbors'].shape == (6, 7)
        assert data['neighbors_features'].shape == (2, 7)
        assert data['neg_neighbors_features'].shape == (6, 7)
        i += 1
    # 40 edges / batch of 2 = 20 samples per epoch, repeated twice -> 40.
    assert i == 40
if __name__ == '__main__':
    # Allow running this test file directly as a script.
    test_graphdata_distributed()