You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

device_matrix_test.cc 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <vector>
  17. #include "common/common_test.h"
  18. #include "common/py_func_graph_fetcher.h"
  19. #include "frontend/parallel/device_matrix.h"
  20. namespace mindspore {
  21. namespace parallel {
  22. class TestDeviceMatrix : public UT::Common {
  23. public:
  24. TestDeviceMatrix() {}
  25. void SetUp() { UT::InitPythonPath(); }
  26. virtual void TearDown() {}
  27. };
  28. TEST_F(TestDeviceMatrix, Test2Dgroup_list) {
  29. RankList dev_list = {0, 1, 2, 3, 4, 5};
  30. Shape shape = {2, 3};
  31. DeviceMatrix arr(0, dev_list, shape);
  32. std::vector<RankList> group_list;
  33. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  34. std::vector<RankList> group_list_expect = {{0, 3}, {0, 1, 2}};
  35. ASSERT_EQ(group_list, group_list_expect);
  36. }
  37. TEST_F(TestDeviceMatrix, Test3Dgroup_list) {
  38. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  39. Shape shape = {2, 2, 3};
  40. DeviceMatrix arr(5, dev_list, shape);
  41. std::vector<RankList> group_list;
  42. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  43. std::vector<RankList> group_list_expect = {{5, 11}, {2, 5}, {3, 4, 5}};
  44. ASSERT_EQ(group_list, group_list_expect);
  45. }
  46. TEST_F(TestDeviceMatrix, Test4DGetAlongDim) {
  47. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  48. Shape shape = {2, 1, 4, 2};
  49. DeviceMatrix arr(5, dev_list, shape);
  50. std::vector<RankList> group_list;
  51. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  52. std::vector<RankList> group_list_expect = {{5, 13}, {5}, {1, 3, 5, 7}, {4, 5}};
  53. ASSERT_EQ(group_list, group_list_expect);
  54. }
  55. TEST_F(TestDeviceMatrix, Test5DGetAlongDim) {
  56. RankList dev_list;
  57. for (int i = 0; i < 144; i++) dev_list.push_back(i);
  58. Shape shape = {3, 4, 2, 3, 2};
  59. DeviceMatrix arr(5, dev_list, shape);
  60. std::vector<RankList> group_list;
  61. if (arr.CreateGroupList() == Status::SUCCESS) group_list = arr.group_list();
  62. std::vector<RankList> group_list_expect = {{5, 53, 101}, {5, 17, 29, 41}, {5, 11}, {1, 3, 5}, {4, 5}};
  63. ASSERT_EQ(group_list, group_list_expect);
  64. }
  65. TEST_F(TestDeviceMatrix, TestCornerCaseGetAlongDim) {
  66. // Shape does not match the number of devices
  67. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8};
  68. Shape shape = {2, 2, 2};
  69. EXPECT_THROW({ DeviceMatrix arr(3, dev_list, shape); }, std::runtime_error);
  70. }
  71. TEST_F(TestDeviceMatrix, TestCornerCase2GetAlongDim) {
  72. // Rank is out of range
  73. RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};
  74. Shape shape = {2, 2, 2};
  75. EXPECT_THROW({ DeviceMatrix arr(8, dev_list, shape); }, std::runtime_error);
  76. }
  77. } // namespace parallel
  78. } // namespace mindspore