
step_parallel_test.cc

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common_test.h"
#include "parallel/step_parallel.h"
#include "parallel/graph_util/generate_graph.h"
#include "common/py_func_graph_fetcher.h"
#include "debug/draw.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/static_analysis.h"

namespace mindspore {
namespace parallel {

extern size_t TOTAL_OPS;

class TestStepParallel : public UT::Common {
 public:
  TestStepParallel() {}
  void SetUp();
  void TearDown() {}
};

void TestStepParallel::SetUp() { UT::InitPythonPath(); }
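
// Builds the global DeviceManager over 20 devices (ranks 0..19) split into
// two stages of 16 and 4 devices, with rank 0 as the local device and "hccl"
// as the communication backend. Tests call this before creating operators.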
void Init_Device_Manager() {
  std::vector<int32_t> dev_list;
  for (int32_t i = 0; i < 20; i++) {
    dev_list.push_back(i);
  }
  std::vector<int32_t> stage_map;
  stage_map.push_back(16);
  stage_map.push_back(4);
  int32_t local_dev = 0;
  // create a new g_device_manager
  g_device_manager = std::make_shared<DeviceManager>();
  g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
}
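
// Constructs a MatMul CNode over two fresh parameters. `condition` controls
// how the abstracts are populated so tests can exercise ExtractShape:
//   0 - all input and output shapes set (valid node)
//   1 - first input's shape left as nullptr
//   2 - output abstract replaced by a scalar value
//   3 - output shape is a TupleShape of the two input shapes
// Any other value leaves the abstracts without shapes.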
CNodePtr Make_Node(Shape x, Shape y, Shape out, int condition = 0) {
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr param1 = func_graph->add_parameter();
  ParameterPtr param2 = func_graph->add_parameter();
  param1->set_name("x");
  param2->set_name("y");
  BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
  BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
  BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
  std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>();
  inputs_x->set_data_type(kNumberTypeInt32);
  inputs_x->set_shape(x);
  std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>();
  inputs_y->set_data_type(kNumberTypeInt32);
  inputs_y->set_shape(y);
  std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>();
  inputs_out->set_data_type(kNumberTypeInt32);
  inputs_out->set_shape(out);
  AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true);
  AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true);
  AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true);
  switch (condition) {
    case 0: {
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      abstract3->set_shape(shape3);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      break;
    }
    case 1: {
      abstract1->set_shape(nullptr);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      break;
    }
    case 2: {
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      abstract3 = abstract::FromValue(1, false);
      break;
    }
    case 3: {
      std::vector<BaseShapePtr> shape_o = {std::make_shared<abstract::Shape>(x), std::make_shared<abstract::Shape>(y)};
      BaseShapePtr shape4 = std::make_shared<abstract::TupleShape>(shape_o);
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      abstract3->set_shape(shape4);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      break;
    }
    default:
      MS_LOG(INFO) << "Do Nothing!";
  }
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(param1);
  inputs.push_back(param2);
  CNodePtr node = func_graph->NewCNode(inputs);
  node->set_abstract(abstract3);
  return node;
}
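
// Builds a two-MatMul graph (node1 = matmul(x, y), node2 = matmul(node1, z))
// wrapped in a FuncGraphManager. Both MatMuls carry transpose flags and a
// "strategy" attribute; `condition` overwrites matmul1's strategy with an
// invalid variant:
//   1 - a scalar value instead of a tuple
//   2 - a tuple whose element is a scalar rather than a dimension list
//   3 - strategies {2, 4}, {2, 4}, whose split of MatMul's contracting
//       dimension does not match between the two inputs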
FuncGraphManagerPtr Make_Manager(int condition = 0) {
  Shape inputs_x = {64, 32};
  Shape inputs_y = {32, 64};
  Shape inputs_z = {64, 128};
  Shape outputs_1 = {64, 64};
  Shape outputs_2 = {64, 128};
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr param1 = func_graph->add_parameter();
  ParameterPtr param2 = func_graph->add_parameter();
  ParameterPtr param3 = func_graph->add_parameter();
  std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>();
  inputs_x_dim->set_data_type(kNumberTypeInt32);
  inputs_x_dim->set_shape(inputs_x);
  std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>();
  inputs_y_dim->set_data_type(kNumberTypeInt32);
  inputs_y_dim->set_shape(inputs_y);
  std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>();
  inputs_z_dim->set_data_type(kNumberTypeInt32);
  inputs_z_dim->set_shape(inputs_z);
  std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>();
  inputs_out1_dim->set_data_type(kNumberTypeInt32);
  inputs_out1_dim->set_shape(outputs_1);
  std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>();
  inputs_out2_dim->set_data_type(kNumberTypeInt32);
  inputs_out2_dim->set_shape(outputs_2);
  AbstractBasePtr abstract_x = abstract::FromValue(inputs_x_dim, true);
  AbstractBasePtr abstract_y = abstract::FromValue(inputs_y_dim, true);
  AbstractBasePtr abstract_z = abstract::FromValue(inputs_z_dim, true);
  AbstractBasePtr abstract_out1 = abstract::FromValue(inputs_out1_dim, true);
  AbstractBasePtr abstract_out2 = abstract::FromValue(inputs_out2_dim, true);
  param1->set_abstract(abstract_x);
  param2->set_abstract(abstract_y);
  param3->set_abstract(abstract_z);
  std::vector<int> v1 = {2, 2};
  std::vector<int> v2 = {2, 4};
  std::vector<ValuePtr> elements = {MakeValue(v1), MakeValue(v2)};
  ValueTuplePtr var = std::make_shared<ValueTuple>(elements);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(param1);
  inputs.push_back(param2);
  CNodePtr node1 = func_graph->NewCNode(inputs);
  node1->set_in_forward_flag(true);
  node1->set_abstract(abstract_out1);
  PrimitivePtr prim1 = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim1->AddAttr("transpose_a", transpose_a);
  prim1->AddAttr("transpose_b", transpose_b);
  prim1->AddAttr("instance_name", MakeValue("matmul1"));
  prim1->AddAttr("strategy", var);
  inputs.clear();
  std::vector<int> v3 = {2, 2};
  std::vector<int> v4 = {2, 4};
  std::vector<ValuePtr> elements2 = {MakeValue(v3), MakeValue(v4)};
  ValueTuplePtr var2 = std::make_shared<ValueTuple>(elements2);
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(node1);
  inputs.push_back(param3);
  CNodePtr node2 = func_graph->NewCNode(inputs);
  node2->set_in_forward_flag(true);
  node2->set_abstract(abstract_out2);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(node2);
  CNodePtr cnode_return = func_graph->NewCNode(inputs);
  cnode_return->set_in_forward_flag(true);
  func_graph->set_return(cnode_return);
  PrimitivePtr prim2 = node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  prim2->AddAttr("transpose_a", transpose_a);
  prim2->AddAttr("transpose_b", transpose_b);
  prim2->AddAttr("instance_name", MakeValue("matmul2"));
  prim2->AddAttr("strategy", var2);
  switch (condition) {
    case 1: {
      prim1->set_attr("strategy", MakeValue(0));
      break;
    }
    case 2: {
      std::vector<ValuePtr> elements_t = {MakeValue(0)};
      ValueTuplePtr var_t = std::make_shared<ValueTuple>(elements_t);
      prim1->set_attr("strategy", var_t);
      break;
    }
    case 3: {
      std::vector<int> vt1 = {2, 4};
      std::vector<int> vt2 = {2, 4};
      std::vector<ValuePtr> elements_t2 = {MakeValue(vt1), MakeValue(vt2)};
      ValueTuplePtr var_t2 = std::make_shared<ValueTuple>(elements_t2);
      prim1->set_attr("strategy", var_t2);
      break;
    }
  }
  std::vector<FuncGraphPtr> func_graphs{func_graph};
  FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(func_graphs, true);
  manager->Init();
  return manager;
}
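
// GetOpPythonPath is expected to resolve distributed operator names such as
// AllReduce and TensorAdd to the mindspore.ops.operations module.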
TEST_F(TestStepParallel, GetPythonPath1) {
  OperatorName operator_name = "AllReduce";
  const std::string expect = "mindspore.ops.operations";
  auto temp = parallel::GetOpPythonPath(operator_name);
  ASSERT_EQ(temp, expect);
}

TEST_F(TestStepParallel, GetPythonPath2) {
  OperatorName operator_name = "TensorAdd";
  const std::string expect = "mindspore.ops.operations";
  auto temp = parallel::GetOpPythonPath(operator_name);
  ASSERT_EQ(temp, expect);
}

TEST_F(TestStepParallel, ExtractStrategy) {
  Dimensions v1 = {2, 2};
  Dimensions v2 = {4, 4};
  std::unordered_map<std::string, ValuePtr> attrs;
  // stage
  ValuePtr val1 = MakeValue(v1);
  ValuePtr val2 = MakeValue(v2);
  std::vector<ValuePtr> elements = {val1, val2};
  ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements);
  attrs["strategy"] = strategy_tuple;
  std::vector<Dimensions> strategy_expect = {v1, v2};
  StrategyPtr strategy = ExtractStrategy(attrs);
  std::vector<Dimensions> strategy_test = strategy->GetInputDim();
  ASSERT_EQ(strategy_expect, strategy_test);
}
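
// The ExtractShape tests reuse Make_Node: conditions 0 and 3 yield parsable
// shapes, while conditions 1, 2, and the default case (4) must raise
// std::runtime_error because an input or output shape is missing or invalid.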
TEST_F(TestStepParallel, ExtractShape) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 4);
  EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
}

TEST_F(TestStepParallel, ExtractShape1) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  std::vector<Shapes> shape_test = ExtractShape(node);
  Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  Shapes outputs_shape = std::vector<Shape>{outputs_dims};
  std::vector<Shapes> shape_expect = {inputs_shape, outputs_shape};
  ASSERT_EQ(shape_test, shape_expect);
}

TEST_F(TestStepParallel, ExtractShape2) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 1);
  EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
}

TEST_F(TestStepParallel, ExtractShape3) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 3);
  Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  std::vector<Shapes> shape_expect = {inputs_shape, inputs_shape};
  std::vector<Shapes> shape_test = ExtractShape(node);
  ASSERT_EQ(shape_test, shape_expect);
}

TEST_F(TestStepParallel, ExtractShape4) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 2);
  Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
}
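
// CreatOpInstance builds the Python-side primitive for an operator name.
// The first test round-trips an AllReduce's attrs through pybind11 and
// checks each converted value; the second expects an unknown op ("ABC")
// to raise std::runtime_error.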
TEST_F(TestStepParallel, CreatOpInstance) {
  ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
  ValuePtr attr1_value = MakeValue("0-1-2");
  Attr attr0 = std::make_pair("op", attr0_value);
  Attr attr1 = std::make_pair("group", attr1_value);
  OperatorAttrs attrs = {attr0, attr1};
  OperatorName op_name = "AllReduce";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  auto op_instance = CreatOpInstance(args.first, op_name, "test");
  ASSERT_TRUE(op_instance);
  PrimitivePyPtr allreduce_ptr = dyn_cast<PrimitivePy>(op_instance);
  ASSERT_TRUE(allreduce_ptr);
  if (nullptr != allreduce_ptr) {
    MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
    auto func = allreduce_ptr->GetComputeFunction();
    if (py::isinstance<py::none>(func)) {
      MS_LOG(EXCEPTION) << allreduce_ptr->name() << "'s compute function is not implemented";
    }
    std::vector<py::object> arglist;
    (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),
                         [](Attr attr) { return ValuePtrToPyData(attr.second); });
    py::object allreduce_pyobj = parse::python_adapter::CallPyFn(
      "mindspore.parallel._utils", "_get_python_op", "AllReduce", "mindspore.ops.operations", "test", arglist);
    py::dict opAttr = py::getattr(allreduce_pyobj, "attrs");
    std::unordered_map<std::string, ValuePtr> attributes{};
    for (auto item : opAttr) {
      if (!py::isinstance<py::str>(item.first)) {
        MS_LOG(EXCEPTION) << "type error in py dict convert";
      }
      std::string name = py::cast<std::string>(item.first);
      MS_LOG(INFO) << "Attr name: " << name;
      ValuePtr converted_ret;
      if (name == "op") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "sum");
      } else if (name == "group") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "0-1-2");
      } else if (name == "fusion") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "0");
      } else if (name == "instance_name") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "test");
      } else {
        MS_LOG(EXCEPTION) << "Test failed";
      }
      attributes.emplace(name, converted_ret);
    }
  }
}

TEST_F(TestStepParallel, CreatOpInstance1) {
  OperatorAttrs attrs;
  OperatorName op_name = "ABC";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  EXPECT_THROW({ CreatOpInstance(args.first, op_name, "test"); }, std::runtime_error);
}
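
// OperatorInstance maps a primitive plus its attrs and shapes to an
// OperatorInfo. With the TOTAL_OPS counter reset to 0, the created MatMul
// info is expected to be named "MatMulInfo00" (apparently the base name
// suffixed with instance counters).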
TEST_F(TestStepParallel, OperatorInstance) {
  Init_Device_Manager();
  // create attrs and prim
  PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim->set_attr("transpose_a", transpose_a);
  prim->set_attr("transpose_b", transpose_b);
  auto attrs = prim->attrs();
  // create strategy
  std::vector<Dimensions> strategy = {{2, 2}, {2, 4}};
  StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
  // create shape
  Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  TOTAL_OPS = 0;
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  matmul_info->Init(strategyPtr);
  std::string name_expect = "MatMulInfo00";
  std::string name_test = matmul_info->name();
  ASSERT_EQ(name_expect, name_test);
}
TEST_F(TestStepParallel, ExtractInformation) {
  Init_Device_Manager();
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
}

TEST_F(TestStepParallel, ExtractInformation2) {
  Init_Device_Manager();
  FuncGraphManagerPtr manager = Make_Manager(2);
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error);
}

TEST_F(TestStepParallel, ExtractInformation3) {
  Init_Device_Manager();
  FuncGraphManagerPtr manager = Make_Manager(3);
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error);
}
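
// ForwardCommunication inserts the given communication operators behind a
// node's forward output. Test 1 inserts two AllReduce ops after each MatMul
// and verifies the rewritten graph; tests 2 and 3 expect std::runtime_error
// when the graph's manager is detached or the operator name is unknown.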
TEST_F(TestStepParallel, ForwardCommunication1) {
  Init_Device_Manager();
  ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
  ValuePtr attr1_value = MakeValue("0-1-2");
  Attr attr0 = std::make_pair("op", attr0_value);
  Attr attr1 = std::make_pair("group", attr1_value);
  OperatorAttrs attrs = {attr0, attr1};
  OperatorName op_name = "AllReduce";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  Operator op = std::make_pair(op_name, args);
  OperatorVector op_list = {op, op};
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    PrimitivePtr prim = cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    if (prim->name() == "MatMul") {
      ForwardCommunication(op_list, cnode);
      draw::Draw("./forwardcommunication.dot", func_graph);
    }
  }
  AnfNodeSet after_nodes = manager->all_nodes();
  for (auto &node : after_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto &inputs = node->cast<CNodePtr>()->inputs();
    PrimitivePtr prim = inputs[0]->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    if (prim->name() == "return" || prim->name() == "MatMul") {
      if (!inputs[1]->isa<Parameter>()) {
        CNodePtr pre_node = inputs[1]->cast<CNodePtr>();
        PrimitivePtr pre_prim = pre_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
        CNodePtr pre_node2 = pre_node->input(1)->cast<CNodePtr>();
        PrimitivePtr pre_prim2 = pre_node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
        ASSERT_EQ("AllReduce", pre_prim->name());
        ASSERT_EQ("AllReduce", pre_prim2->name());
      }
    }
  }
}

TEST_F(TestStepParallel, ForwardCommunication2) {
  OperatorVector op_list;
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    func_graph->set_manager(nullptr);
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->name() == "MatMul") {
      EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
      break;
    }
  }
}

TEST_F(TestStepParallel, ForwardCommunication3) {
  OperatorVector op_list;
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->name() == "MatMul") {
      OperatorAttrs attrs;
      OperatorParams operator_param;
      OperatorArgs args = std::make_pair(attrs, operator_param);
      Operator op = std::make_pair("ABC", args);
      OperatorVector op_list = {op};
      EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
      break;
    }
  }
}
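
// GetTensorInLayout queries the tensor layout that a node's input inherits
// from the preceding operator's distribution info; here the MatMul output
// layout should keep the full {64, 64} tensor shape.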
TEST_F(TestStepParallel, GetTensorInLayout) {
  Init_Device_Manager();
  // create attrs and prim
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  std::vector<AnfNodePtr> inputs(node->inputs());
  CNodePtr node1 = func_graph->NewCNode(inputs);
  PrimitivePtr prim = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim->set_attr("transpose_a", transpose_a);
  prim->set_attr("transpose_b", transpose_b);
  auto attrs = prim->attrs();
  // create strategy
  std::vector<Dimensions> strategy = {{2, 2}, {2, 4}};
  StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
  // create shape
  Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  matmul_info->Init(strategyPtr);
  node->set_operator_info(matmul_info);
  OperatorInfoPtr distribute_operator_pre = node->operator_info();
  TensorLayout tensorlayout_e;
  std::vector<int32_t> array = {64, 64};
  TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);
  std::vector<int32_t> tensor_shape_test = tensorlayout.tensor_shape().array();
  ASSERT_EQ(array, tensor_shape_test);
}

}  // namespace parallel
}  // namespace mindspore