You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

step_auto_parallel_test.cc 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "frontend/parallel/step_parallel.h"
  18. #include "frontend/parallel/step_auto_parallel.h"
  19. #include "frontend/parallel/auto_parallel/edge_costmodel.h"
  20. #include "frontend/parallel/ops_info/operator_info.h"
  21. #include "frontend/operator/ops.h"
  22. #include "pipeline/jit/static_analysis/static_analysis.h"
  23. namespace mindspore {
  24. namespace parallel {
  25. class TestStepAutoParallel : public UT::Common {
  26. public:
  27. TestStepAutoParallel() {}
  28. void SetUp();
  29. void TearDown() {}
  30. };
  31. void TestStepAutoParallel::SetUp() {
  32. RankList dev_list;
  33. for (int32_t i = 0; i < 20; i++) {
  34. dev_list.push_back(i);
  35. }
  36. RankList stage_map;
  37. stage_map.push_back(16);
  38. stage_map.push_back(4);
  39. int32_t local_dev = 0;
  40. // create a new g_device_manager
  41. g_device_manager = std::make_shared<DeviceManager>();
  42. g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
  43. }
  44. CNodePtr Create_Node(Shape x, Shape y, Shape out) {
  45. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  46. ParameterPtr param1 = func_graph->add_parameter();
  47. ParameterPtr param2 = func_graph->add_parameter();
  48. param1->set_name("x");
  49. param2->set_name("y");
  50. BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
  51. BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
  52. BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
  53. AbstractBasePtr abstract1 = abstract::FromValue(static_cast<int64_t>(1), false);
  54. AbstractBasePtr abstract2 = abstract::FromValue(static_cast<int64_t>(1), false);
  55. AbstractBasePtr abstract3 = abstract::FromValue(static_cast<int64_t>(1), false);
  56. abstract1->set_shape(shape1);
  57. abstract2->set_shape(shape2);
  58. abstract3->set_shape(shape3);
  59. param1->set_abstract(abstract1);
  60. param2->set_abstract(abstract2);
  61. std::vector<AnfNodePtr> inputs;
  62. inputs.push_back(NewValueNode(prim::kPrimMatMul));
  63. inputs.push_back(param1);
  64. inputs.push_back(param2);
  65. CNodePtr node = func_graph->NewCNode(inputs);
  66. PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  67. ValuePtr transpose_a = MakeValue(false);
  68. ValuePtr transpose_b = MakeValue(false);
  69. prim->set_attr("transpose_a", transpose_a);
  70. prim->set_attr("transpose_b", transpose_b);
  71. node->set_abstract(abstract3);
  72. return node;
  73. }
  74. CNodePtr Create_two_nodes(Shape x, Shape y, Shape z, Shape w, Shape out) {
  75. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  76. ParameterPtr paramX = func_graph->add_parameter();
  77. ParameterPtr paramY = func_graph->add_parameter();
  78. ParameterPtr paramW = func_graph->add_parameter();
  79. paramX->set_name("x");
  80. paramY->set_name("y");
  81. paramW->set_name("w");
  82. BaseShapePtr shapeX = std::make_shared<abstract::Shape>(x);
  83. BaseShapePtr shapeY = std::make_shared<abstract::Shape>(y);
  84. BaseShapePtr shapeZ = std::make_shared<abstract::Shape>(z);
  85. BaseShapePtr shapeW = std::make_shared<abstract::Shape>(w);
  86. BaseShapePtr shapeOut = std::make_shared<abstract::Shape>(out);
  87. AbstractBasePtr abstractX = abstract::FromValue(static_cast<int64_t>(1), false);
  88. AbstractBasePtr abstractY = abstract::FromValue(static_cast<int64_t>(1), false);
  89. AbstractBasePtr abstractZ = abstract::FromValue(static_cast<int64_t>(1), false);
  90. AbstractBasePtr abstractW = abstract::FromValue(static_cast<int64_t>(1), false);
  91. AbstractBasePtr abstractOut = abstract::FromValue(static_cast<int64_t>(1), false);
  92. abstractX->set_shape(shapeX);
  93. abstractY->set_shape(shapeY);
  94. abstractZ->set_shape(shapeZ);
  95. abstractW->set_shape(shapeW);
  96. abstractOut->set_shape(shapeOut);
  97. paramX->set_abstract(abstractX);
  98. paramY->set_abstract(abstractY);
  99. paramW->set_abstract(abstractW);
  100. std::vector<AnfNodePtr> MatMul_1_inputs;
  101. MatMul_1_inputs.push_back(NewValueNode(prim::kPrimMatMul));
  102. MatMul_1_inputs.push_back(paramX);
  103. MatMul_1_inputs.push_back(paramY);
  104. CNodePtr MatMul_1_node = func_graph->NewCNode(MatMul_1_inputs);
  105. PrimitivePtr prim = MatMul_1_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  106. ValuePtr transpose_a = MakeValue(false);
  107. ValuePtr transpose_b = MakeValue(false);
  108. prim->set_attr("transpose_a", transpose_a);
  109. prim->set_attr("transpose_b", transpose_b);
  110. MatMul_1_node->set_abstract(abstractZ);
  111. std::vector<AnfNodePtr> MatMul_2_inputs;
  112. MatMul_2_inputs.push_back(NewValueNode(prim::kPrimMatMul));
  113. MatMul_2_inputs.push_back(MatMul_1_node);
  114. MatMul_2_inputs.push_back(paramW);
  115. CNodePtr MatMul_2_node = func_graph->NewCNode(MatMul_2_inputs);
  116. PrimitivePtr prim2 = MatMul_2_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  117. ValuePtr transpose_a_2 = MakeValue(false);
  118. ValuePtr transpose_b_2 = MakeValue(false);
  119. prim2->set_attr("transpose_a", transpose_a);
  120. prim2->set_attr("transpose_b", transpose_b);
  121. MatMul_2_node->set_abstract(abstractOut);
  122. return MatMul_2_node;
  123. }
  124. TEST_F(TestStepAutoParallel, test_create_op_instance) {
  125. Shape inputs_x_dims = {64, 32};
  126. Shape inputs_y_dims = {32, 64};
  127. Shape outputs_dims = {64, 64};
  128. CNodePtr node = Create_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  129. bool result = node->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>();
  130. ASSERT_EQ(result, true);
  131. // creat prim and attrs
  132. PrimitivePtr prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  133. auto attrs = prim->attrs();
  134. // creat shape
  135. Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  136. Shapes outputs_shape = std::vector<Shape>{outputs_dims};
  137. std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  138. StrategyPtr strategyPtr;
  139. std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
  140. node->set_user_data<OperatorInfo>(matmul_info);
  141. std::string name_expect = "MatMulInfo00";
  142. std::string name_test = matmul_info->name();
  143. ASSERT_EQ(name_expect, name_test);
  144. }
  145. TEST_F(TestStepAutoParallel, test_create_edge) {
  146. Shape inputs_x_dims = {64, 32};
  147. Shape inputs_y_dims = {32, 64};
  148. Shape outputs_z_dims = {64, 64};
  149. Shape inputs_w_dims = {64, 128};
  150. Shape outputs_dim = {64, 128};
  151. CNodePtr node = Create_two_nodes(inputs_x_dims, inputs_y_dims, outputs_z_dims, inputs_w_dims, outputs_dim);
  152. // u-->v
  153. PrimitivePtr v_prim = node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  154. auto v_attrs = v_prim->attrs();
  155. PrimitivePtr u_prim = node->input(1)->cast<CNodePtr>()->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  156. auto u_attrs = u_prim->attrs();
  157. // creat v node
  158. Shapes v_inputs_shape = std::vector<Shape>{outputs_z_dims, inputs_w_dims};
  159. Shapes v_outputs_shape = std::vector<Shape>{outputs_dim};
  160. std::vector<Shapes> v_shape = {v_inputs_shape, v_outputs_shape};
  161. StrategyPtr v_strategyPtr;
  162. std::shared_ptr<OperatorInfo> v_matmul_info = NewOperatorInstance(v_prim, v_attrs, v_shape);
  163. // create u node
  164. Shapes u_inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  165. Shapes u_outputs_shape = std::vector<Shape>{outputs_z_dims};
  166. std::vector<Shapes> u_shape = {u_inputs_shape, u_outputs_shape};
  167. StrategyPtr u_strategyPtr;
  168. std::shared_ptr<OperatorInfo> u_matmul_info = NewOperatorInstance(u_prim, u_attrs, u_shape);
  169. std::string edge_name = u_prim->name() + "-" + v_prim->name();
  170. std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(edge_name, u_matmul_info, v_matmul_info, 0, 0, false);
  171. std::string expected_name = "MatMul-MatMul";
  172. ASSERT_EQ(edge_ptr->edge_name(), expected_name);
  173. }
  174. } // namespace parallel
  175. } // namespace mindspore