You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graphdata.py 9.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import random
  16. import pytest
  17. import numpy as np
  18. import mindspore.dataset as ds
  19. from mindspore import log as logger
  20. from mindspore.dataset.engine import SamplingStrategy
  21. DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
  22. SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
  23. def test_graphdata_getfullneighbor():
  24. """
  25. Test get all neighbors
  26. """
  27. logger.info('test get all neighbors.\n')
  28. g = ds.GraphData(DATASET_FILE, 2)
  29. nodes = g.get_all_nodes(1)
  30. assert len(nodes) == 10
  31. neighbor = g.get_all_neighbors(nodes, 2)
  32. assert neighbor.shape == (10, 6)
  33. row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
  34. assert row_tensor[0].shape == (10, 6)
  35. def test_graphdata_getnodefeature_input_check():
  36. """
  37. Test get node feature input check
  38. """
  39. logger.info('test getnodefeature input check.\n')
  40. g = ds.GraphData(DATASET_FILE)
  41. with pytest.raises(TypeError):
  42. input_list = [1, [1, 1]]
  43. g.get_node_feature(input_list, [1])
  44. with pytest.raises(TypeError):
  45. input_list = [[1, 1], 1]
  46. g.get_node_feature(input_list, [1])
  47. with pytest.raises(TypeError):
  48. input_list = [[1, 1], [1, 1, 1]]
  49. g.get_node_feature(input_list, [1])
  50. with pytest.raises(TypeError):
  51. input_list = [[1, 1, 1], [1, 1]]
  52. g.get_node_feature(input_list, [1])
  53. with pytest.raises(TypeError):
  54. input_list = [[1, 1], [1, [1, 1]]]
  55. g.get_node_feature(input_list, [1])
  56. with pytest.raises(TypeError):
  57. input_list = [[1, 1], [[1, 1], 1]]
  58. g.get_node_feature(input_list, [1])
  59. with pytest.raises(TypeError):
  60. input_list = [[1, 1], [1, 1]]
  61. g.get_node_feature(input_list, 1)
  62. with pytest.raises(TypeError):
  63. input_list = [[1, 0.1], [1, 1]]
  64. g.get_node_feature(input_list, 1)
  65. with pytest.raises(TypeError):
  66. input_list = np.array([[1, 0.1], [1, 1]])
  67. g.get_node_feature(input_list, 1)
  68. with pytest.raises(TypeError):
  69. input_list = [[1, 1], [1, 1]]
  70. g.get_node_feature(input_list, ["a"])
  71. with pytest.raises(TypeError):
  72. input_list = [[1, 1], [1, 1]]
  73. g.get_node_feature(input_list, [1, "a"])
  74. def test_graphdata_getsampledneighbors():
  75. """
  76. Test sampled neighbors
  77. """
  78. logger.info('test get sampled neighbors.\n')
  79. g = ds.GraphData(DATASET_FILE, 1)
  80. edges = g.get_all_edges(0)
  81. nodes = g.get_nodes_from_edges(edges)
  82. assert len(nodes) == 40
  83. neighbor = g.get_sampled_neighbors(
  84. np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM)
  85. assert neighbor.shape == (10, 9)
  86. neighbor = g.get_sampled_neighbors(
  87. np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT)
  88. assert neighbor.shape == (10, 9)
  89. def test_graphdata_getnegsampledneighbors():
  90. """
  91. Test neg sampled neighbors
  92. """
  93. logger.info('test get negative sampled neighbors.\n')
  94. g = ds.GraphData(DATASET_FILE, 2)
  95. nodes = g.get_all_nodes(1)
  96. assert len(nodes) == 10
  97. neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2)
  98. assert neighbor.shape == (10, 6)
  99. def test_graphdata_graphinfo():
  100. """
  101. Test graph info
  102. """
  103. logger.info('test graph info.\n')
  104. g = ds.GraphData(DATASET_FILE, 2)
  105. graph_info = g.graph_info()
  106. assert graph_info['node_type'] == [1, 2]
  107. assert graph_info['edge_type'] == [0]
  108. assert graph_info['node_num'] == {1: 10, 2: 10}
  109. assert graph_info['edge_num'] == {0: 40}
  110. assert graph_info['node_feature_type'] == [1, 2, 3, 4]
  111. assert graph_info['edge_feature_type'] == [1, 2]
  112. class RandomBatchedSampler(ds.Sampler):
  113. # RandomBatchedSampler generate random sequence without replacement in a batched manner
  114. def __init__(self, index_range, num_edges_per_sample):
  115. super().__init__()
  116. self.index_range = index_range
  117. self.num_edges_per_sample = num_edges_per_sample
  118. def __iter__(self):
  119. indices = [i+1 for i in range(self.index_range)]
  120. # Reset random seed here if necessary
  121. # random.seed(0)
  122. random.shuffle(indices)
  123. for i in range(0, self.index_range, self.num_edges_per_sample):
  124. # Drop reminder
  125. if i + self.num_edges_per_sample <= self.index_range:
  126. yield indices[i: i + self.num_edges_per_sample]
  127. class GNNGraphDataset():
  128. def __init__(self, g, batch_num):
  129. self.g = g
  130. self.batch_num = batch_num
  131. def __len__(self):
  132. # Total sample size of GNN dataset
  133. # In this case, the size should be total_num_edges/num_edges_per_sample
  134. return self.g.graph_info()['edge_num'][0] // self.batch_num
  135. def __getitem__(self, index):
  136. # index will be a list of indices yielded from RandomBatchedSampler
  137. # Fetch edges/nodes/samples/features based on indices
  138. nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
  139. nodes = nodes[:, 0]
  140. neg_nodes = self.g.get_neg_sampled_neighbors(
  141. node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
  142. nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
  143. 2, 2], neighbor_types=[2, 1])
  144. neg_nodes_neighbors = self.g.get_sampled_neighbors(
  145. node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
  146. nodes_neighbors_features = self.g.get_node_feature(
  147. node_list=nodes_neighbors, feature_types=[2, 3])
  148. neg_neighbors_features = self.g.get_node_feature(
  149. node_list=neg_nodes_neighbors, feature_types=[2, 3])
  150. return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
  151. def test_graphdata_generatordataset():
  152. """
  153. Test generator dataset
  154. """
  155. logger.info('test generator dataset.\n')
  156. g = ds.GraphData(DATASET_FILE)
  157. batch_num = 2
  158. edge_num = g.graph_info()['edge_num'][0]
  159. out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
  160. dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
  161. sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4)
  162. dataset = dataset.repeat(2)
  163. itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
  164. i = 0
  165. for data in itr:
  166. assert data['neighbors'].shape == (2, 7)
  167. assert data['neg_neighbors'].shape == (6, 7)
  168. assert data['neighbors_features'].shape == (2, 7)
  169. assert data['neg_neighbors_features'].shape == (6, 7)
  170. i += 1
  171. assert i == 40
  172. def test_graphdata_randomwalkdefault():
  173. """
  174. Test random walk defaults
  175. """
  176. logger.info('test randomwalk with default parameters.\n')
  177. g = ds.GraphData(SOCIAL_DATA_FILE, 1)
  178. nodes = g.get_all_nodes(1)
  179. assert len(nodes) == 33
  180. meta_path = [1 for _ in range(39)]
  181. walks = g.random_walk(nodes, meta_path)
  182. assert walks.shape == (33, 40)
  183. def test_graphdata_randomwalk():
  184. """
  185. Test random walk
  186. """
  187. logger.info('test random walk with given parameters.\n')
  188. g = ds.GraphData(SOCIAL_DATA_FILE, 1)
  189. nodes = g.get_all_nodes(1)
  190. assert len(nodes) == 33
  191. meta_path = [1 for _ in range(39)]
  192. walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1)
  193. assert walks.shape == (33, 40)
  194. def test_graphdata_getedgefeature():
  195. """
  196. Test get edge feature
  197. """
  198. logger.info('test get_edge_feature.\n')
  199. g = ds.GraphData(DATASET_FILE)
  200. edges = g.get_all_edges(0)
  201. features = g.get_edge_feature(edges, [1, 2])
  202. assert features[0].shape == (40,)
  203. assert features[1].shape == (40,)
  204. def test_graphdata_getedgesfromnodes():
  205. """
  206. Test get edges from nodes
  207. """
  208. logger.info('test get_edges_from_nodes\n')
  209. g = ds.GraphData(DATASET_FILE)
  210. nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)]
  211. edges = g.get_edges_from_nodes(nodes_pair_list)
  212. assert edges.tolist() == [1, 9, 31, 17, 20, 40]
  213. if __name__ == '__main__':
  214. test_graphdata_getfullneighbor()
  215. test_graphdata_getnodefeature_input_check()
  216. test_graphdata_getsampledneighbors()
  217. test_graphdata_getnegsampledneighbors()
  218. test_graphdata_graphinfo()
  219. test_graphdata_generatordataset()
  220. test_graphdata_randomwalkdefault()
  221. test_graphdata_randomwalk()
  222. test_graphdata_getedgefeature()
  223. test_graphdata_getedgesfromnodes()