You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graphdata.py 2.8 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import numpy as np
  17. import mindspore.dataset as ds
  18. from mindspore import log as logger
  # Relative path to the graph test data that the tests in this file pass to
  # ds.GraphData.
  DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
  def test_graphdata_getfullneighbor():
      """Check neighbor lookup and feature retrieval on the test graph.

      Opens the graph stored at DATASET_FILE, then chains three queries:
      get_all_nodes(1), get_all_neighbors(nodes, 2) and
      get_node_feature(neighbors, [2, 3]). The node count and the shapes of
      the neighbor array and of the first returned feature array are
      asserted against the values expected for this fixture.
      """
      # NOTE(review): the second constructor argument is presumably the
      # worker count; confirm against the GraphData signature.
      g = ds.GraphData(DATASET_FILE, 2)

      nodes = g.get_all_nodes(1)
      assert len(nodes) == 10

      neighbor = g.get_all_neighbors(nodes, 2)
      assert neighbor.shape == (10, 6)

      # get_node_feature is given a plain nested list here, hence tolist().
      row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
      assert row_tensor[0].shape == (10, 6)
  def test_graphdata_getnodefeature_input_check():
      """Check that get_node_feature raises TypeError for malformed arguments.

      Each case below calls g.get_node_feature with one badly formed
      argument pair and expects a TypeError. The cases run in a fixed
      order and the first one that does not raise fails the test.
      """
      g = ds.GraphData(DATASET_FILE)

      def assert_type_error(node_list, feature_types):
          """Assert that get_node_feature rejects the given arguments."""
          # Only the call under test sits inside pytest.raises, so an error
          # raised while building an argument can never satisfy the check.
          with pytest.raises(TypeError):
              g.get_node_feature(node_list, feature_types)

      # Node list mixes a bare scalar with a nested list.
      assert_type_error([1, [1, 1]], [1])
      assert_type_error([[1, 1], 1], [1])

      # Rows of unequal length.
      assert_type_error([[1, 1], [1, 1, 1]], [1])
      assert_type_error([[1, 1, 1], [1, 1]], [1])

      # A row that itself contains a nested list.
      assert_type_error([[1, 1], [1, [1, 1]]], [1])
      assert_type_error([[1, 1], [[1, 1], 1]], [1])

      # Second argument is a bare int rather than a list.
      assert_type_error([[1, 1], [1, 1]], 1)

      # Node list contains a float, as a nested list and as an ndarray.
      # NOTE(review): the second argument is also a bare int in these two
      # cases, so either argument could be what triggers the TypeError.
      assert_type_error([[1, 0.1], [1, 1]], 1)
      assert_type_error(np.array([[1, 0.1], [1, 1]]), 1)

      # Second argument is a list containing a string.
      assert_type_error([[1, 1], [1, 1]], ["a"])
      assert_type_error([[1, 1], [1, 1]], [1, "a"])
  if __name__ == '__main__':
      # When executed as a script, run both tests in order and log after
      # each one returns.
      test_graphdata_getfullneighbor()
      logger.info('test_graphdata_getfullneighbor Ended.\n')
      test_graphdata_getnodefeature_input_check()
      logger.info('test_graphdata_getnodefeature_input_check Ended.\n')