You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graphdata.py 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import random
  16. import pytest
  17. import numpy as np
  18. import mindspore.dataset as ds
  19. from mindspore import log as logger
  20. from mindspore.dataset.engine import SamplingStrategy
  21. from mindspore.dataset.engine import OutputFormat
  22. DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
  23. SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
  24. def test_graphdata_getfullneighbor():
  25. """
  26. Test get all neighbors
  27. """
  28. logger.info('test get all neighbors.\n')
  29. g = ds.GraphData(DATASET_FILE, 2)
  30. nodes = g.get_all_nodes(1)
  31. assert len(nodes) == 10
  32. neighbor = g.get_all_neighbors(nodes, 2)
  33. assert neighbor.shape == (10, 6)
  34. row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
  35. assert row_tensor[0].shape == (10, 6)
  36. def test_graphdata_getallneighbors_special_format():
  37. """
  38. Test get all neighbors with special format
  39. """
  40. logger.info('test get all neighbors with special format.\n')
  41. g = ds.GraphData(DATASET_FILE, 2)
  42. nodes = g.get_all_nodes(1)
  43. assert len(nodes) == 10
  44. neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
  45. assert neighbor_coo.shape == (20, 2)
  46. offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
  47. assert offset_table.shape == (10,)
  48. assert neighbor_csr.shape == (20,)
  49. def test_graphdata_getnodefeature_input_check():
  50. """
  51. Test get node feature input check
  52. """
  53. logger.info('test getnodefeature input check.\n')
  54. g = ds.GraphData(DATASET_FILE)
  55. with pytest.raises(TypeError):
  56. input_list = [1, [1, 1]]
  57. g.get_node_feature(input_list, [1])
  58. with pytest.raises(TypeError):
  59. input_list = [[1, 1], 1]
  60. g.get_node_feature(input_list, [1])
  61. with pytest.raises(TypeError):
  62. input_list = [[1, 1], [1, 1, 1]]
  63. g.get_node_feature(input_list, [1])
  64. with pytest.raises(TypeError):
  65. input_list = [[1, 1, 1], [1, 1]]
  66. g.get_node_feature(input_list, [1])
  67. with pytest.raises(TypeError):
  68. input_list = [[1, 1], [1, [1, 1]]]
  69. g.get_node_feature(input_list, [1])
  70. with pytest.raises(TypeError):
  71. input_list = [[1, 1], [[1, 1], 1]]
  72. g.get_node_feature(input_list, [1])
  73. with pytest.raises(TypeError):
  74. input_list = [[1, 1], [1, 1]]
  75. g.get_node_feature(input_list, 1)
  76. with pytest.raises(TypeError):
  77. input_list = [[1, 0.1], [1, 1]]
  78. g.get_node_feature(input_list, 1)
  79. with pytest.raises(TypeError):
  80. input_list = np.array([[1, 0.1], [1, 1]])
  81. g.get_node_feature(input_list, 1)
  82. with pytest.raises(TypeError):
  83. input_list = [[1, 1], [1, 1]]
  84. g.get_node_feature(input_list, ["a"])
  85. with pytest.raises(TypeError):
  86. input_list = [[1, 1], [1, 1]]
  87. g.get_node_feature(input_list, [1, "a"])
  88. def test_graphdata_getsampledneighbors():
  89. """
  90. Test sampled neighbors
  91. """
  92. logger.info('test get sampled neighbors.\n')
  93. g = ds.GraphData(DATASET_FILE, 1)
  94. edges = g.get_all_edges(0)
  95. nodes = g.get_nodes_from_edges(edges)
  96. assert len(nodes) == 40
  97. neighbor = g.get_sampled_neighbors(
  98. np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM)
  99. assert neighbor.shape == (10, 9)
  100. neighbor = g.get_sampled_neighbors(
  101. np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT)
  102. assert neighbor.shape == (10, 9)
  103. def test_graphdata_getnegsampledneighbors():
  104. """
  105. Test neg sampled neighbors
  106. """
  107. logger.info('test get negative sampled neighbors.\n')
  108. g = ds.GraphData(DATASET_FILE, 2)
  109. nodes = g.get_all_nodes(1)
  110. assert len(nodes) == 10
  111. neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2)
  112. assert neighbor.shape == (10, 6)
  113. def test_graphdata_graphinfo():
  114. """
  115. Test graph info
  116. """
  117. logger.info('test graph info.\n')
  118. g = ds.GraphData(DATASET_FILE, 2)
  119. graph_info = g.graph_info()
  120. assert graph_info['node_type'] == [1, 2]
  121. assert graph_info['edge_type'] == [0]
  122. assert graph_info['node_num'] == {1: 10, 2: 10}
  123. assert graph_info['edge_num'] == {0: 40}
  124. assert graph_info['node_feature_type'] == [1, 2, 3, 4]
  125. assert graph_info['edge_feature_type'] == [1, 2]
  126. class RandomBatchedSampler(ds.Sampler):
  127. # RandomBatchedSampler generate random sequence without replacement in a batched manner
  128. def __init__(self, index_range, num_edges_per_sample):
  129. super().__init__()
  130. self.index_range = index_range
  131. self.num_edges_per_sample = num_edges_per_sample
  132. def __iter__(self):
  133. indices = [i+1 for i in range(self.index_range)]
  134. # Reset random seed here if necessary
  135. # random.seed(0)
  136. random.shuffle(indices)
  137. for i in range(0, self.index_range, self.num_edges_per_sample):
  138. # Drop reminder
  139. if i + self.num_edges_per_sample <= self.index_range:
  140. yield indices[i: i + self.num_edges_per_sample]
  141. class GNNGraphDataset():
  142. def __init__(self, g, batch_num):
  143. self.g = g
  144. self.batch_num = batch_num
  145. def __len__(self):
  146. # Total sample size of GNN dataset
  147. # In this case, the size should be total_num_edges/num_edges_per_sample
  148. return self.g.graph_info()['edge_num'][0] // self.batch_num
  149. def __getitem__(self, index):
  150. # index will be a list of indices yielded from RandomBatchedSampler
  151. # Fetch edges/nodes/samples/features based on indices
  152. nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
  153. nodes = nodes[:, 0]
  154. neg_nodes = self.g.get_neg_sampled_neighbors(
  155. node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
  156. nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
  157. 2, 2], neighbor_types=[2, 1])
  158. neg_nodes_neighbors = self.g.get_sampled_neighbors(
  159. node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
  160. nodes_neighbors_features = self.g.get_node_feature(
  161. node_list=nodes_neighbors, feature_types=[2, 3])
  162. neg_neighbors_features = self.g.get_node_feature(
  163. node_list=neg_nodes_neighbors, feature_types=[2, 3])
  164. return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
  165. def test_graphdata_generatordataset():
  166. """
  167. Test generator dataset
  168. """
  169. logger.info('test generator dataset.\n')
  170. g = ds.GraphData(DATASET_FILE)
  171. batch_num = 2
  172. edge_num = g.graph_info()['edge_num'][0]
  173. out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
  174. dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
  175. sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4)
  176. dataset = dataset.repeat(2)
  177. itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
  178. i = 0
  179. for data in itr:
  180. assert data['neighbors'].shape == (2, 7)
  181. assert data['neg_neighbors'].shape == (6, 7)
  182. assert data['neighbors_features'].shape == (2, 7)
  183. assert data['neg_neighbors_features'].shape == (6, 7)
  184. i += 1
  185. assert i == 40
  186. def test_graphdata_randomwalkdefault():
  187. """
  188. Test random walk defaults
  189. """
  190. logger.info('test randomwalk with default parameters.\n')
  191. g = ds.GraphData(SOCIAL_DATA_FILE, 1)
  192. nodes = g.get_all_nodes(1)
  193. assert len(nodes) == 33
  194. meta_path = [1 for _ in range(39)]
  195. walks = g.random_walk(nodes, meta_path)
  196. assert walks.shape == (33, 40)
  197. def test_graphdata_randomwalk():
  198. """
  199. Test random walk
  200. """
  201. logger.info('test random walk with given parameters.\n')
  202. g = ds.GraphData(SOCIAL_DATA_FILE, 1)
  203. nodes = g.get_all_nodes(1)
  204. assert len(nodes) == 33
  205. meta_path = [1 for _ in range(39)]
  206. walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1)
  207. assert walks.shape == (33, 40)
  208. def test_graphdata_getedgefeature():
  209. """
  210. Test get edge feature
  211. """
  212. logger.info('test get_edge_feature.\n')
  213. g = ds.GraphData(DATASET_FILE)
  214. edges = g.get_all_edges(0)
  215. features = g.get_edge_feature(edges, [1, 2])
  216. assert features[0].shape == (40,)
  217. assert features[1].shape == (40,)
  218. def test_graphdata_getedgefeature_invalidcase():
  219. """
  220. Test get edge feature with invalid edge id, 0 should be returned for those invalid edge id in correct index
  221. """
  222. logger.info('test get_edge_feature.\n')
  223. g = ds.GraphData(DATASET_FILE)
  224. edges = g.get_all_edges(0)
  225. edges[-6] = -1
  226. features = g.get_edge_feature(edges, [1, 2])
  227. assert features[0].shape == (40,)
  228. assert features[1].shape == (40,)
  229. assert features[0][-6] == 0
  230. assert features[1][-6] == 0.
  231. def test_graphdata_getnodefeature_invalidcase():
  232. """
  233. Test get node feature with invalid node id, 0 should be returned for those invalid node id in correct index
  234. """
  235. logger.info('test get_node_feature.\n')
  236. g = ds.GraphData(DATASET_FILE)
  237. nodes = g.get_all_nodes(node_type=1)
  238. nodes[5] = -1
  239. features = g.get_node_feature(node_list=nodes, feature_types=[2, 3])
  240. assert features[0].shape == (10,)
  241. assert features[1].shape == (10,)
  242. assert features[0][5] == 0.
  243. assert features[1][5] == 0
  244. def test_graphdata_getedgesfromnodes():
  245. """
  246. Test get edges from nodes
  247. """
  248. logger.info('test get_edges_from_nodes\n')
  249. g = ds.GraphData(DATASET_FILE)
  250. nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)]
  251. edges = g.get_edges_from_nodes(node_list=nodes_pair_list)
  252. assert edges.tolist() == [1, 9, 31, 17, 20, 40]
  253. if __name__ == '__main__':
  254. test_graphdata_getfullneighbor()
  255. test_graphdata_getnodefeature_input_check()
  256. test_graphdata_getsampledneighbors()
  257. test_graphdata_getnegsampledneighbors()
  258. test_graphdata_graphinfo()
  259. test_graphdata_generatordataset()
  260. test_graphdata_randomwalkdefault()
  261. test_graphdata_randomwalk()
  262. test_graphdata_getedgefeature()
  263. test_graphdata_getedgesfromnodes()
  264. test_graphdata_getnodefeature_invalidcase()
  265. test_graphdata_getedgefeature_invalidcase()