You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

device_manager_test.cc 3.9 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <list>
  17. #include "common/common_test.h"
  18. #include "frontend/parallel/device.h"
  19. #include "frontend/parallel/device_manager.h"
  20. #include "frontend/parallel/group_manager.h"
  21. namespace mindspore {
  22. namespace parallel {
  23. class TestDevice : public UT::Common {
  24. public:
  25. TestDevice() {}
  26. void SetUp();
  27. void TearDown();
  28. Device dev_1;
  29. Device dev_2;
  30. };
  31. void TestDevice::SetUp() {
  32. std::string name = "#1";
  33. dev_1 = Device(name, std::int32_t(1));
  34. dev_2 = Device(std::int32_t(2));
  35. }
  36. void TestDevice::TearDown() {
  37. // destroy resources
  38. }
  39. TEST_F(TestDevice, test_device) {
  40. std::string name = "#1";
  41. int32_t dev1_rank = 1;
  42. int32_t dev2_rank = 2;
  43. ASSERT_STREQ(dev_1.name().data(), name.data());
  44. ASSERT_EQ(dev_1.rank(), dev1_rank);
  45. ASSERT_EQ(dev_2.rank(), dev2_rank);
  46. }
  47. // need to complete
  48. class TestStage : public UT::Common {};
  49. class TestDeviceManager : public UT::Common {
  50. public:
  51. TestDeviceManager() {}
  52. void SetUp();
  53. void TearDown();
  54. DeviceManager dm_;
  55. };
  56. void TestDeviceManager::SetUp() { dm_ = DeviceManager::GetInstance(); }
  57. void TestDeviceManager::TearDown() {
  58. // destroy resources
  59. }
  60. TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) {
  61. RankList dev_list;
  62. RankList stage_map;
  63. int32_t local_dev = 0;
  64. dev_list.push_back(5);
  65. dev_list.push_back(3);
  66. dev_list.push_back(1);
  67. dev_list.push_back(0);
  68. stage_map.push_back(2);
  69. stage_map.push_back(2);
  70. ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS);
  71. ASSERT_EQ(dm_.DeviceNum(), 4);
  72. ASSERT_EQ(dm_.stage_num(), (int32_t)(2));
  73. RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
  74. RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
  75. ASSERT_EQ(dev_list_0.size(), 2);
  76. ASSERT_EQ(dev_list_1.size(), 2);
  77. RankList::iterator it = dev_list_0.begin();
  78. ASSERT_EQ((*it), int32_t(5));
  79. it++;
  80. ASSERT_EQ((*it), int32_t(3));
  81. it = dev_list_1.begin();
  82. ASSERT_EQ((*it), int32_t(1));
  83. it++;
  84. ASSERT_EQ((*it), int32_t(0));
  85. }
  86. TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) {
  87. Device one = dm_.CreateNewDeviceByRank(int32_t(3));
  88. ASSERT_EQ(one.rank(), int32_t(3));
  89. }
  90. TEST_F(TestDeviceManager, test_CreateDeviceListByRankList) {
  91. std::vector<Device> dev_list;
  92. RankList rlist;
  93. rlist.push_back(int32_t(2));
  94. rlist.push_back(int32_t(1));
  95. dev_list = dm_.CreateDeviceListByRankList(rlist);
  96. std::vector<Device>::iterator it = dev_list.begin();
  97. ASSERT_EQ(it->rank(), int32_t(2));
  98. it++;
  99. ASSERT_EQ(it->rank(), int32_t(1));
  100. }
  101. TEST_F(TestDeviceManager, test_StageID) {
  102. RankList dev_list;
  103. RankList stage_map;
  104. int32_t local_dev = 2;
  105. dev_list.push_back(0);
  106. dev_list.push_back(1);
  107. dev_list.push_back(2);
  108. dev_list.push_back(3);
  109. stage_map.push_back(2);
  110. stage_map.push_back(2);
  111. ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS);
  112. ASSERT_EQ(dm_.DeviceNum(), 4);
  113. ASSERT_EQ(dm_.stage_num(), 2);
  114. ASSERT_EQ(dm_.stage_id(), 1);
  115. ASSERT_EQ(dm_.rank_index_in_stage(), 0);
  116. ASSERT_EQ(dm_.GetDeviceListInThisStage().back(), 3);
  117. RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
  118. RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
  119. ASSERT_EQ(dev_list_0.size(), 2);
  120. ASSERT_EQ(dev_list_1.size(), 2);
  121. }
  122. } // namespace parallel
  123. } // namespace mindspore