You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

device_matrix_test.cc 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <vector>
  17. #include "common/common_test.h"
  18. #include "common/py_func_graph_fetcher.h"
  19. #include "frontend/parallel/device_matrix.h"
  20. namespace mindspore {
  21. namespace parallel {
  22. class TestDeviceMatrix : public UT::Common {
  23. public:
  24. TestDeviceMatrix() {}
  25. void SetUp() { UT::InitPythonPath(); }
  26. virtual void TearDown() {}
  27. };
  28. TEST_F(TestDeviceMatrix, Test2Dgroup_list) {
  29. RankList dev_list = {0, 1, 2, 3, 4, 5};
  30. Shape shape = {2, 3};
  31. DeviceMatrix arr(0, dev_list, shape);
  32. std::vector<RankList> group_list;
  33. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  34. std::vector<RankList> group_list_expect = {{0, 3}, {0, 1, 2}};
  35. ASSERT_EQ(group_list, group_list_expect);
  36. }
  37. TEST_F(TestDeviceMatrix, Test3Dgroup_list) {
  38. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  39. Shape shape = {2, 2, 3};
  40. DeviceMatrix arr(5, dev_list, shape);
  41. std::vector<RankList> group_list;
  42. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  43. std::vector<RankList> group_list_expect = {{5, 11}, {2, 5}, {3, 4, 5}};
  44. ASSERT_EQ(group_list, group_list_expect);
  45. }
  46. TEST_F(TestDeviceMatrix, Test4DGetAlongDim) {
  47. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  48. Shape shape = {2, 1, 4, 2};
  49. DeviceMatrix arr(5, dev_list, shape);
  50. std::vector<RankList> group_list;
  51. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  52. std::vector<RankList> group_list_expect = {{5, 13}, {5}, {1, 3, 5, 7}, {4, 5}};
  53. ASSERT_EQ(group_list, group_list_expect);
  54. }
  55. TEST_F(TestDeviceMatrix, Test5DGetAlongDim) {
  56. RankList dev_list;
  57. for (int i = 0; i < 144; i++) dev_list.push_back(i);
  58. Shape shape = {3, 4, 2, 3, 2};
  59. DeviceMatrix arr(5, dev_list, shape);
  60. std::vector<RankList> group_list;
  61. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  62. std::vector<RankList> group_list_expect = {{5, 53, 101}, {5, 17, 29, 41}, {5, 11}, {1, 3, 5}, {4, 5}};
  63. ASSERT_EQ(group_list, group_list_expect);
  64. }
  65. TEST_F(TestDeviceMatrix, TestCornerCaseGetAlongDim) {
  66. // Shape does not match the number of devices
  67. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8};
  68. Shape shape = {2, 2, 2};
  69. EXPECT_THROW({ DeviceMatrix arr(3, dev_list, shape); }, std::runtime_error);
  70. }
  71. TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceOne) {
  72. RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
  73. Shape tensor_map = {-1, 0};
  74. RankList rank_list;
  75. Shape shape = {4, 2};
  76. DeviceMatrix arr(0, dev_list, shape);
  77. arr.GetDevicesByTensorMap(tensor_map, &rank_list);
  78. RankList rank_list_except = {3, 9, 100, 0};
  79. ASSERT_EQ(rank_list, rank_list_except);
  80. }
  81. TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceTwo) {
  82. RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
  83. Shape tensor_map = {1, 0};
  84. RankList rank_list;
  85. Shape shape = {4, 2};
  86. DeviceMatrix arr(0, dev_list, shape);
  87. arr.GetDevicesByTensorMap(tensor_map, &rank_list);
  88. RankList rank_list_except = {0};
  89. ASSERT_EQ(rank_list, rank_list_except);
  90. }
  91. TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapNoramalOrder2D) {
  92. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};
  93. Shape tensor_map = {-1, 0};
  94. RankList rank_list;
  95. Shape shape = {4, 2};
  96. DeviceMatrix arr(6, dev_list, shape);
  97. arr.GetDevicesByTensorMap(tensor_map, &rank_list);
  98. RankList rank_list_except = {0, 2, 4, 6};
  99. ASSERT_EQ(rank_list, rank_list_except);
  100. }
  101. TEST_F(TestDeviceMatrix, TestCornerCase2GetAlongDim) {
  102. // Rank is out of range
  103. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};
  104. Shape shape = {2, 2, 2};
  105. EXPECT_THROW({ DeviceMatrix arr(8, dev_list, shape); }, std::runtime_error);
  106. }
  107. } // namespace parallel
  108. } // namespace mindspore