You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graphdata.py 2.6 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
  19. def test_graphdata_getfullneighbor():
  20. g = ds.GraphData(DATASET_FILE, 2)
  21. nodes = g.get_all_nodes(1)
  22. assert len(nodes) == 10
  23. nodes_list = nodes.tolist()
  24. neighbor = g.get_all_neighbors(nodes_list, 2)
  25. assert neighbor.shape == (10, 6)
  26. row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
  27. assert row_tensor[0].shape == (10, 6)
  28. def test_graphdata_getnodefeature_input_check():
  29. g = ds.GraphData(DATASET_FILE)
  30. with pytest.raises(TypeError):
  31. input_list = [1, [1, 1]]
  32. g.get_node_feature(input_list, [1])
  33. with pytest.raises(TypeError):
  34. input_list = [[1, 1], 1]
  35. g.get_node_feature(input_list, [1])
  36. with pytest.raises(TypeError):
  37. input_list = [[1, 1], [1, 1, 1]]
  38. g.get_node_feature(input_list, [1])
  39. with pytest.raises(TypeError):
  40. input_list = [[1, 1, 1], [1, 1]]
  41. g.get_node_feature(input_list, [1])
  42. with pytest.raises(TypeError):
  43. input_list = [[1, 1], [1, [1, 1]]]
  44. g.get_node_feature(input_list, [1])
  45. with pytest.raises(TypeError):
  46. input_list = [[1, 1], [[1, 1], 1]]
  47. g.get_node_feature(input_list, [1])
  48. with pytest.raises(TypeError):
  49. input_list = [[1, 1], [1, 1]]
  50. g.get_node_feature(input_list, 1)
  51. with pytest.raises(TypeError):
  52. input_list = [[1, 1], [1, 1]]
  53. g.get_node_feature(input_list, ["a"])
  54. with pytest.raises(TypeError):
  55. input_list = [[1, 1], [1, 1]]
  56. g.get_node_feature(input_list, [1, "a"])
  57. if __name__ == '__main__':
  58. test_graphdata_getfullneighbor()
  59. logger.info('test_graphdata_getfullneighbor Ended.\n')
  60. test_graphdata_getnodefeature_input_check()
  61. logger.info('test_graphdata_getnodefeature_input_check Ended.\n')