You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reshape_test.cc 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "common/common_test.h"
#include "parallel/strategy.h"
#include "parallel/ops_info/reshape_info.h"
#include "parallel/device_manager.h"
#include "parallel/step_parallel.h"
  24. namespace mindspore {
  25. namespace parallel {
  26. class ReshapeInfo;
  27. using ReshapeInfoPtr = std::shared_ptr<ReshapeInfo>;
  28. ReshapeInfoPtr reshape;
  29. class TestReshapeInfo : public UT::Common {
  30. public:
  31. TestReshapeInfo() {}
  32. void SetUp();
  33. void TearDown() {}
  34. };
  35. void TestReshapeInfo::SetUp() {
  36. std::vector<int32_t> dev_list;
  37. for (int32_t i = 0; i < 34; i++) {
  38. dev_list.push_back(i);
  39. }
  40. std::vector<int32_t> stage_map;
  41. stage_map.push_back(32);
  42. stage_map.push_back(2);
  43. int32_t local_dev = 0;
  44. // create a new g_device_manager
  45. g_device_manager = std::make_shared<DeviceManager>();
  46. g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
  47. std::unordered_map<std::string, ValuePtr> attr;
  48. Shapes inputs_shape = {{32, 512, 7, 7}};
  49. Shapes outputs_shape = {{32, 25088}};
  50. std::vector<int> axis = {32, 25088};
  51. ValuePtr val0;
  52. ValuePtr val1 = MakeValue(axis);
  53. std::vector<ValuePtr> val = {val0, val1};
  54. reshape = std::make_shared<ReshapeInfo>("reshape_info", inputs_shape, outputs_shape, attr);
  55. reshape->set_input_value(val);
  56. }
  57. TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
  58. std::vector<Dimensions> inputs = {{4, 1, 1, 1}};
  59. StrategyPtr strategy = NewStrategy(0, inputs);
  60. reshape->Init(strategy);
  61. std::vector<int32_t> dev_matrix_shape = reshape->dev_matrix_shape();
  62. std::vector<int32_t> expect = {8, 4};
  63. ASSERT_EQ(dev_matrix_shape, expect);
  64. }
  65. TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
  66. std::vector<Dimensions> inputs = {{32, 1, 1, 1}};
  67. StrategyPtr strategy = NewStrategy(0, inputs);
  68. reshape->Init(strategy);
  69. std::vector<int32_t> dev_matrix_shape = reshape->dev_matrix_shape();
  70. std::vector<int32_t> expect = {32};
  71. ASSERT_EQ(dev_matrix_shape, expect);
  72. }
  73. TEST_F(TestReshapeInfo, InferSliceShape1) {
  74. std::vector<Dimensions> str = {{4, 1, 1, 1}};
  75. StrategyPtr strategy = NewStrategy(0, str);
  76. reshape->Init(strategy);
  77. std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  78. std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  79. Shape input_slice_shape_expect = {8, 512, 7, 7};
  80. Shape output_slice_shape_expect = {8, 25088};
  81. TensorInfo input_tensor_info = inputs.at(0);
  82. TensorInfo output_tensor_info = outputs.at(0);
  83. Shape input_slice_shape = input_tensor_info.slice_shape();
  84. Shape output_slice_shape = output_tensor_info.slice_shape();
  85. ASSERT_EQ(input_slice_shape, input_slice_shape_expect);
  86. ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
  87. }
  88. TEST_F(TestReshapeInfo, InferSliceShape2) {
  89. std::vector<Dimensions> str = {{32, 1, 1, 1}};
  90. StrategyPtr strategy = NewStrategy(0, str);
  91. reshape->Init(strategy);
  92. std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  93. std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  94. Shape input_slice_shape_expect = {1, 512, 7, 7};
  95. Shape output_slice_shape_expect = {1, 25088};
  96. TensorInfo input_tensor_info = inputs.at(0);
  97. TensorInfo output_tensor_info = outputs.at(0);
  98. Shape input_slice_shape = input_tensor_info.slice_shape();
  99. Shape output_slice_shape = output_tensor_info.slice_shape();
  100. ASSERT_EQ(input_slice_shape, input_slice_shape_expect);
  101. ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
  102. }
  103. TEST_F(TestReshapeInfo, GetTensorLayout1) {
  104. std::vector<Dimensions> str = {{4, 1, 1, 1}};
  105. StrategyPtr strategy = NewStrategy(0, str);
  106. reshape->Init(strategy);
  107. std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  108. std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  109. TensorMap input_expect = {0, -1, -1, -1};
  110. TensorMap output_expect = {0, -1};
  111. TensorInfo input_tensor_info = inputs.at(0);
  112. TensorInfo output_tensor_info = outputs.at(0);
  113. Map input_tensor_map = input_tensor_info.tensor_layout().origin_tensor_map();
  114. Map output_tensor_map = output_tensor_info.tensor_layout().origin_tensor_map();
  115. ASSERT_EQ(input_tensor_map.array(), input_expect);
  116. ASSERT_EQ(output_tensor_map.array(), output_expect);
  117. }
  118. TEST_F(TestReshapeInfo, GetTensorLayout2) {
  119. std::vector<Dimensions> str = {{32, 1, 1, 1}};
  120. StrategyPtr strategy = NewStrategy(0, str);
  121. reshape->Init(strategy);
  122. std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  123. std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  124. TensorMap input_expect = {0, -1, -1, -1};
  125. TensorMap output_expect = {0, -1};
  126. TensorInfo input_tensor_info = inputs.at(0);
  127. TensorInfo output_tensor_info = outputs.at(0);
  128. Map input_tensor_map = input_tensor_info.tensor_layout().origin_tensor_map();
  129. Map output_tensor_map = output_tensor_info.tensor_layout().origin_tensor_map();
  130. ASSERT_EQ(input_tensor_map.array(), input_expect);
  131. ASSERT_EQ(output_tensor_map.array(), output_expect);
  132. }
  133. TEST_F(TestReshapeInfo, GetForwardOp1) {
  134. std::vector<Dimensions> inputs = {{4, 1, 1, 1}};
  135. StrategyPtr strategy = NewStrategy(0, inputs);
  136. reshape->Init(strategy);
  137. OperatorVector forward_op = reshape->forward_op();
  138. size_t size = forward_op.size();
  139. ASSERT_EQ(size, 0);
  140. }
  141. TEST_F(TestReshapeInfo, GetMirrorOPs1) {
  142. std::vector<Dimensions> inputs = {{4, 1, 1, 1}};
  143. StrategyPtr strategy = NewStrategy(0, inputs);
  144. reshape->Init(strategy);
  145. MirrorOps mirror_ops = reshape->mirror_ops();
  146. size_t size = mirror_ops.size();
  147. ASSERT_EQ(size, 2);
  148. }
  149. TEST_F(TestReshapeInfo, CheckStrategy1) {
  150. std::vector<Dimensions> inputs = {{1, 4, 8}};
  151. StrategyPtr strategy = NewStrategy(0, inputs);
  152. Status ret = reshape->Init(strategy);
  153. ASSERT_EQ(ret, FAILED);
  154. }
  155. TEST_F(TestReshapeInfo, CheckStrategy2) {
  156. std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
  157. StrategyPtr strategy = NewStrategy(0, inputs);
  158. Status ret = reshape->Init(strategy);
  159. ASSERT_EQ(ret, FAILED);
  160. }
  161. TEST_F(TestReshapeInfo, CheckStrategy3) {
  162. std::vector<Dimensions> inputs = {{4, 1, 1, 1}};
  163. StrategyPtr strategy = NewStrategy(0, inputs);
  164. Status ret = reshape->Init(strategy);
  165. ASSERT_EQ(ret, SUCCESS);
  166. }
  167. } // namespace parallel
  168. } // namespace mindspore