You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transpose_test.cc 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include <list>
  18. #include <vector>
  19. #include "common/common_test.h"
  20. #include "frontend/parallel/strategy.h"
  21. #include "frontend/parallel/ops_info/transpose_info.h"
  22. #include "frontend/parallel/device_manager.h"
  23. #include "frontend/parallel/step_parallel.h"
  24. namespace mindspore {
  25. namespace parallel {
  // Forward declaration and alias for the operator under test (presumably also
  // provided by transpose_info.h; repeating them is harmless).
  class TransposeInfo;
  using TransposeInfoPtr = std::shared_ptr<TransposeInfo>;

  // Operator instance shared by the tests below; TestTransposeInfo::SetUp()
  // assigns a fresh TransposeInfo to it before every TEST_F.
  // `static` gives it internal linkage so this generic name cannot clash with a
  // same-named global in another test translation unit of the same binary.
  static TransposeInfoPtr transpose;
  29. class TestTransposeInfo : public UT::Common {
  30. public:
  31. TestTransposeInfo() {}
  32. void SetUp();
  33. void TearDown() {}
  34. };
  35. void TestTransposeInfo::SetUp() {
  36. RankList dev_list;
  37. for (int32_t i = 0; i < 34; i++) {
  38. dev_list.push_back(i);
  39. }
  40. RankList stage_map;
  41. stage_map.push_back(32);
  42. stage_map.push_back(2);
  43. int32_t local_dev = 0;
  44. // create a new g_device_manager
  45. g_device_manager = std::make_shared<DeviceManager>();
  46. g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
  47. std::unordered_map<std::string, ValuePtr> attr;
  48. Shapes inputs_shape = {{128, 64}};
  49. Shapes outputs_shape = {{64, 128}};
  50. std::vector<int> axis = {1, 0};
  51. ValuePtr val0;
  52. ValuePtr val1 = MakeValue(axis);
  53. std::vector<ValuePtr> val = {val0, val1};
  54. transpose = std::make_shared<TransposeInfo>("transpose_info", inputs_shape, outputs_shape, attr);
  55. transpose->set_input_value(val);
  56. }
  57. TEST_F(TestTransposeInfo, InferDevMatrixShape1) {
  58. Strategys inputs = {{4, 8}};
  59. StrategyPtr strategy = NewStrategy(0, inputs);
  60. transpose->Init(strategy);
  61. Shape dev_matrix_shape = transpose->dev_matrix_shape();
  62. Shape expect = {4, 8};
  63. ASSERT_EQ(dev_matrix_shape, expect);
  64. }
  65. TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
  66. Strategys inputs = {{4, 1}};
  67. StrategyPtr strategy = NewStrategy(0, inputs);
  68. transpose->Init(strategy);
  69. Shape dev_matrix_shape = transpose->dev_matrix_shape();
  70. Shape expect = {8, 4, 1};
  71. ASSERT_EQ(dev_matrix_shape, expect);
  72. }
  73. TEST_F(TestTransposeInfo, InferSliceShape1) {
  74. Strategys str = {{4, 8}};
  75. StrategyPtr strategy = NewStrategy(0, str);
  76. transpose->Init(strategy);
  77. std::vector<TensorInfo> inputs = transpose->inputs_tensor_info();
  78. std::vector<TensorInfo> outputs = transpose->outputs_tensor_info();
  79. Shape input_slice_shape_expect = {32, 8};
  80. Shape output_slice_shape_expect = {8, 32};
  81. TensorInfo input_tensor_info = inputs.at(0);
  82. TensorInfo output_tensor_info = outputs.at(0);
  83. Shape input_slice_shape = input_tensor_info.slice_shape();
  84. Shape output_slice_shape = output_tensor_info.slice_shape();
  85. ASSERT_EQ(input_slice_shape, input_slice_shape_expect);
  86. ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
  87. }
  88. TEST_F(TestTransposeInfo, GetTensorLayout1) {
  89. Strategys str = {{4, 8}};
  90. StrategyPtr strategy = NewStrategy(0, str);
  91. transpose->Init(strategy);
  92. std::vector<TensorInfo> inputs = transpose->inputs_tensor_info();
  93. std::vector<TensorInfo> outputs = transpose->outputs_tensor_info();
  94. TensorMap input_expect = {1, 0};
  95. TensorMap output_expect = {0, 1};
  96. TensorInfo input_tensor_info = inputs.at(0);
  97. TensorInfo output_tensor_info = outputs.at(0);
  98. Map input_tensor_map = input_tensor_info.tensor_layout().origin_tensor_map();
  99. Map output_tensor_map = output_tensor_info.tensor_layout().origin_tensor_map();
  100. ASSERT_EQ(input_tensor_map.array(), input_expect);
  101. ASSERT_EQ(output_tensor_map.array(), output_expect);
  102. }
  103. TEST_F(TestTransposeInfo, GetForwardOp1) {
  104. Strategys inputs = {{4, 8}};
  105. StrategyPtr strategy = NewStrategy(0, inputs);
  106. transpose->Init(strategy);
  107. OperatorVector forward_op = transpose->forward_op();
  108. size_t size = forward_op.size();
  109. ASSERT_EQ(size, 0);
  110. }
  111. TEST_F(TestTransposeInfo, GetMirrorOPs1) {
  112. Strategys inputs = {{4, 8}};
  113. StrategyPtr strategy = NewStrategy(0, inputs);
  114. transpose->Init(strategy);
  115. MirrorOps mirror_ops = transpose->mirror_ops();
  116. size_t size = mirror_ops.size();
  117. ASSERT_EQ(size, 0);
  118. }
  119. TEST_F(TestTransposeInfo, CheckStrategy1) {
  120. Strategys inputs = {{1, 4, 8}};
  121. StrategyPtr strategy = NewStrategy(0, inputs);
  122. Status ret = transpose->Init(strategy);
  123. ASSERT_EQ(ret, FAILED);
  124. }
  125. TEST_F(TestTransposeInfo, CheckStrategy2) {
  126. Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
  127. StrategyPtr strategy = NewStrategy(0, inputs);
  128. Status ret = transpose->Init(strategy);
  129. ASSERT_EQ(ret, FAILED);
  130. }
  131. TEST_F(TestTransposeInfo, CheckStrategy3) {
  132. Strategys inputs = {{4, 8}};
  133. StrategyPtr strategy = NewStrategy(0, inputs);
  134. Status ret = transpose->Init(strategy);
  135. ASSERT_EQ(ret, SUCCESS);
  136. }
  137. TEST_F(TestTransposeInfo, AutoStrategy1) {
  138. ASSERT_EQ(transpose->GenerateStrategies(0), Status::SUCCESS);
  139. std::vector<std::shared_ptr<StrategyWithCost>> sc = transpose->GetStrategyCost();
  140. Shapes splittable_inputs = {{1, 1}};
  141. std::vector<StrategyPtr> sp_vector;
  142. Shapes inputs_shape = {{128, 64}};
  143. GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector);
  144. ASSERT_EQ(sc.size(), sp_vector.size());
  145. }
  146. } // namespace parallel
  147. } // namespace mindspore