You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_graphdata.py 7.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import random
  16. import pytest
  17. import numpy as np
  18. import mindspore.dataset as ds
  19. from mindspore import log as logger
  20. DATASET_FILE = "../data/mindrecord/testGraphData/testdata"  # graph test data handed to ds.GraphData by every test below; relative path, so it depends on the working directory
  21. def test_graphdata_getfullneighbor():
  22. g = ds.GraphData(DATASET_FILE, 2)
  23. nodes = g.get_all_nodes(1)
  24. assert len(nodes) == 10
  25. neighbor = g.get_all_neighbors(nodes, 2)
  26. assert neighbor.shape == (10, 6)
  27. row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
  28. assert row_tensor[0].shape == (10, 6)
  29. def test_graphdata_getnodefeature_input_check():
  30. g = ds.GraphData(DATASET_FILE)
  31. with pytest.raises(TypeError):
  32. input_list = [1, [1, 1]]
  33. g.get_node_feature(input_list, [1])
  34. with pytest.raises(TypeError):
  35. input_list = [[1, 1], 1]
  36. g.get_node_feature(input_list, [1])
  37. with pytest.raises(TypeError):
  38. input_list = [[1, 1], [1, 1, 1]]
  39. g.get_node_feature(input_list, [1])
  40. with pytest.raises(TypeError):
  41. input_list = [[1, 1, 1], [1, 1]]
  42. g.get_node_feature(input_list, [1])
  43. with pytest.raises(TypeError):
  44. input_list = [[1, 1], [1, [1, 1]]]
  45. g.get_node_feature(input_list, [1])
  46. with pytest.raises(TypeError):
  47. input_list = [[1, 1], [[1, 1], 1]]
  48. g.get_node_feature(input_list, [1])
  49. with pytest.raises(TypeError):
  50. input_list = [[1, 1], [1, 1]]
  51. g.get_node_feature(input_list, 1)
  52. with pytest.raises(TypeError):
  53. input_list = [[1, 0.1], [1, 1]]
  54. g.get_node_feature(input_list, 1)
  55. with pytest.raises(TypeError):
  56. input_list = np.array([[1, 0.1], [1, 1]])
  57. g.get_node_feature(input_list, 1)
  58. with pytest.raises(TypeError):
  59. input_list = [[1, 1], [1, 1]]
  60. g.get_node_feature(input_list, ["a"])
  61. with pytest.raises(TypeError):
  62. input_list = [[1, 1], [1, 1]]
  63. g.get_node_feature(input_list, [1, "a"])
  64. def test_graphdata_getsampledneighbors():
  65. g = ds.GraphData(DATASET_FILE, 1)
  66. edges = g.get_all_edges(0)
  67. nodes = g.get_nodes_from_edges(edges)
  68. assert len(nodes) == 40
  69. neighbor = g.get_sampled_neighbors(
  70. np.unique(nodes[0:21, 0]), [2, 3], [2, 1])
  71. assert neighbor.shape == (10, 9)
  72. def test_graphdata_getnegsampledneighbors():
  73. g = ds.GraphData(DATASET_FILE, 2)
  74. nodes = g.get_all_nodes(1)
  75. assert len(nodes) == 10
  76. neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2)
  77. assert neighbor.shape == (10, 6)
  78. def test_graphdata_graphinfo():
  79. g = ds.GraphData(DATASET_FILE, 2)
  80. graph_info = g.graph_info()
  81. assert graph_info['node_type'] == [1, 2]
  82. assert graph_info['edge_type'] == [0]
  83. assert graph_info['node_num'] == {1: 10, 2: 10}
  84. assert graph_info['edge_num'] == {0: 40}
  85. assert graph_info['node_feature_type'] == [1, 2, 3, 4]
  86. assert graph_info['edge_feature_type'] == []
  87. class RandomBatchedSampler(ds.Sampler):
  88. # RandomBatchedSampler generate random sequence without replacement in a batched manner
  89. def __init__(self, index_range, num_edges_per_sample):
  90. super().__init__()
  91. self.index_range = index_range
  92. self.num_edges_per_sample = num_edges_per_sample
  93. def __iter__(self):
  94. indices = [i+1 for i in range(self.index_range)]
  95. # Reset random seed here if necessary
  96. # random.seed(0)
  97. random.shuffle(indices)
  98. for i in range(0, self.index_range, self.num_edges_per_sample):
  99. # Drop reminder
  100. if i + self.num_edges_per_sample <= self.index_range:
  101. yield indices[i: i + self.num_edges_per_sample]
  102. class GNNGraphDataset():
  103. def __init__(self, g, batch_num):
  104. self.g = g
  105. self.batch_num = batch_num
  106. def __len__(self):
  107. # Total sample size of GNN dataset
  108. # In this case, the size should be total_num_edges/num_edges_per_sample
  109. return self.g.graph_info()['edge_num'][0] // self.batch_num
  110. def __getitem__(self, index):
  111. # index will be a list of indices yielded from RandomBatchedSampler
  112. # Fetch edges/nodes/samples/features based on indices
  113. nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
  114. nodes = nodes[:, 0]
  115. neg_nodes = self.g.get_neg_sampled_neighbors(
  116. node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
  117. nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
  118. 2, 2], neighbor_types=[2, 1])
  119. neg_nodes_neighbors = self.g.get_sampled_neighbors(
  120. node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
  121. nodes_neighbors_features = self.g.get_node_feature(
  122. node_list=nodes_neighbors, feature_types=[2, 3])
  123. neg_neighbors_features = self.g.get_node_feature(
  124. node_list=neg_nodes_neighbors, feature_types=[2, 3])
  125. return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
  126. def test_graphdata_generatordataset():
  127. g = ds.GraphData(DATASET_FILE)
  128. batch_num = 2
  129. edge_num = g.graph_info()['edge_num'][0]
  130. out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
  131. dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
  132. sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4)
  133. dataset = dataset.repeat(2)
  134. itr = dataset.create_dict_iterator()
  135. i = 0
  136. for data in itr:
  137. assert data['neighbors'].shape == (2, 7)
  138. assert data['neg_neighbors'].shape == (6, 7)
  139. assert data['neighbors_features'].shape == (2, 7)
  140. assert data['neg_neighbors_features'].shape == (6, 7)
  141. i += 1
  142. assert i == 40
  143. if __name__ == '__main__':
  144. test_graphdata_getfullneighbor()
  145. logger.info('test_graphdata_getfullneighbor Ended.\n')
  146. test_graphdata_getnodefeature_input_check()
  147. logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
  148. test_graphdata_getsampledneighbors()
  149. logger.info('test_graphdata_getsampledneighbors Ended.\n')
  150. test_graphdata_getnegsampledneighbors()
  151. logger.info('test_graphdata_getnegsampledneighbors Ended.\n')
  152. test_graphdata_graphinfo()
  153. logger.info('test_graphdata_graphinfo Ended.\n')
  154. test_graphdata_generatordataset()
  155. logger.info('test_graphdata_generatordataset Ended.\n')