You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

group_manager_test.cc 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <list>
  17. #include "parallel/device_manager.h"
  18. #include "common/common_test.h"
  19. #include "parallel/device.h"
  20. #include "parallel/group_manager.h"
  21. namespace mindspore {
  22. namespace parallel {
  23. extern DeviceManagerPtr g_device_manager;  // declared only; defined in another translation unit and not referenced by the tests shown here
  24. class TestGroup : public UT::Common {
  25. public:
  26. TestGroup() {}
  27. void SetUp();
  28. void TearDown();
  29. Status Init();
  30. Group gp;
  31. };
  32. void TestGroup::SetUp() { gp = Group(); }
  33. void TestGroup::TearDown() {
  34. // destroy resources
  35. }
  36. Status TestGroup::Init() {
  37. std::string gname = "1-2";
  38. std::vector<Device> dev_list;
  39. Device one = Device(int32_t(1));
  40. dev_list.push_back(one);
  41. Device two = Device(int32_t(2));
  42. dev_list.push_back(two);
  43. return gp.Init(gname, dev_list);
  44. }
  45. TEST_F(TestGroup, test_Init) { ASSERT_EQ(Init(), Status::SUCCESS); }
  46. TEST_F(TestGroup, test_GetDevicesList) {
  47. Init();
  48. std::vector<Device> res_dev_list = gp.GetDevicesList();
  49. std::vector<Device>::iterator it = res_dev_list.begin();
  50. ASSERT_EQ(it->rank(), int32_t(1));
  51. it++;
  52. ASSERT_EQ(it->rank(), int32_t(2));
  53. }
  54. TEST_F(TestGroup, test_IsInThisGroup) {
  55. Init();
  56. ASSERT_TRUE(gp.IsInThisGroup(int32_t(1)));
  57. ASSERT_TRUE(gp.IsInThisGroup(int32_t(2)));
  58. ASSERT_FALSE(gp.IsInThisGroup(int32_t(3)));
  59. }
  60. class TestGroupManager : public UT::Common {
  61. public:
  62. TestGroupManager() {}
  63. void SetUp();
  64. void TearDown();
  65. Status Init(Group** gp_ptr);
  66. GroupManager gm;
  67. };
  68. void TestGroupManager::SetUp() { gm = GroupManager(); }
  69. void TestGroupManager::TearDown() {
  70. // destroy resources
  71. }
  72. Status TestGroupManager::Init(Group** gp_ptr) {
  73. std::string gname = "1-2";
  74. std::vector<Device> dev_list;
  75. Device one = Device(int32_t(1));
  76. dev_list.push_back(one);
  77. Device two = Device(int32_t(2));
  78. dev_list.push_back(two);
  79. return gm.CreateGroup(gname, dev_list, *gp_ptr);
  80. }
  81. TEST_F(TestGroupManager, test_CreateGroup) {
  82. // testing for creating a group
  83. Group* gp_ptr = new Group();
  84. ASSERT_EQ(Init(&gp_ptr), Status::SUCCESS);
  85. std::vector<Device> res_dev_list = gp_ptr->GetDevicesList();
  86. std::vector<Device>::iterator it = res_dev_list.begin();
  87. ASSERT_EQ(it->rank(), int32_t(1));
  88. it++;
  89. ASSERT_EQ(it->rank(), int32_t(2));
  90. delete gp_ptr;
  91. // testing for creating a group with an existing group name
  92. std::vector<Device> dev_list2;
  93. Device three = Device(int32_t(3));
  94. dev_list2.push_back(three);
  95. Device four = Device(int32_t(4));
  96. dev_list2.push_back(four);
  97. gp_ptr = new Group();
  98. ASSERT_EQ(gm.CreateGroup("1-2", dev_list2, gp_ptr), Status::SUCCESS);
  99. ASSERT_STREQ(gp_ptr->name().data(), "1-2");
  100. std::vector<Device> res_dev_list2 = gp_ptr->GetDevicesList();
  101. std::vector<Device>::iterator it2 = res_dev_list2.begin();
  102. ASSERT_EQ(it2->rank(), int32_t(1));
  103. it2++;
  104. ASSERT_EQ(it2->rank(), int32_t(2));
  105. delete gp_ptr;
  106. gp_ptr = nullptr;
  107. }
  108. TEST_F(TestGroupManager, test_FindGroup) {
  109. std::string gname = "1-2";
  110. Group* gp_ptr = new Group();
  111. Group* gp_ptr2 = new Group();
  112. ASSERT_EQ(Init(&gp_ptr), Status::SUCCESS);
  113. ASSERT_EQ(gm.FindGroup(gname, &gp_ptr2), Status::SUCCESS);
  114. std::vector<Device> res_dev_list = gp_ptr2->GetDevicesList();
  115. std::vector<Device>::iterator it = res_dev_list.begin();
  116. ASSERT_EQ(it->rank(), int32_t(1));
  117. it++;
  118. ASSERT_EQ(it->rank(), int32_t(2));
  119. delete gp_ptr;
  120. gp_ptr = nullptr;
  121. std::string gname2 = "3-4";
  122. gp_ptr2 = new Group();
  123. ASSERT_EQ(gm.FindGroup(gname2, &gp_ptr2), Status::FAILED);
  124. delete gp_ptr2;
  125. gp_ptr2 = nullptr;
  126. }
  127. } // namespace parallel
  128. } // namespace mindspore