You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

step_parallel_test.cc 21 kB

4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "frontend/parallel/step_parallel.h"
  18. #include "frontend/parallel/graph_util/generate_graph.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "include/common/debug/draw.h"
  21. #include "frontend/operator/ops.h"
  22. #include "pipeline/jit/static_analysis/static_analysis.h"
  23. #include "include/common/utils/convert_utils_py.h"
  24. namespace mindspore {
  25. namespace parallel {
  26. extern size_t TOTAL_OPS;
// Test fixture for the step_parallel unit tests. SetUp (defined below)
// initializes the embedded Python path and installs a fresh global
// device manager before every test; TearDown has nothing to release.
class TestStepParallel : public UT::Common {
 public:
  TestStepParallel() {}
  void SetUp();
  void TearDown() {}
};
  33. void Init_Device_Manager() {
  34. RankList dev_list;
  35. for (int32_t i = 0; i < 20; i++) {
  36. dev_list.push_back(i);
  37. }
  38. RankList stage_map;
  39. stage_map.push_back(16);
  40. stage_map.push_back(4);
  41. int32_t local_dev = 0;
  42. // create a new g_device_manager
  43. g_device_manager = std::make_shared<DeviceManager>();
  44. g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
  45. }
// Per-test setup: make the UT Python helpers importable, then reset the
// global device manager so each test starts from the same 20-device topology.
void TestStepParallel::SetUp() {
  UT::InitPythonPath();
  Init_Device_Manager();
}
// Builds a MatMul CNode over two fresh parameters "x" and "y" inside a new
// FuncGraph. `condition` controls how the abstracts/shapes are populated so
// tests can exercise the ExtractShape error paths:
//   0 (default) - fully valid: both params and the output carry tensor
//                 abstracts with their shapes set.
//   1 - param1 gets no abstract at all; ExtractShape is expected to throw.
//   2 - params are valid but the output abstract is a scalar int64 value,
//       i.e. it carries no tensor shape.
//   3 - output abstract's shape is a TupleShape made of (x, y).
//   other - nothing is set on the params (only the output abstract is used).
CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr param1 = func_graph->add_parameter();
  ParameterPtr param2 = func_graph->add_parameter();
  param1->set_name("x");
  param2->set_name("y");
  BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
  BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
  BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
  // The shape vectors double as tensor payloads so FromValue can build
  // tensor abstracts for the two inputs and the output.
  std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x);
  std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y);
  std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out);
  AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true);
  AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true);
  AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true);
  switch (condition) {
    case 0: {
      // Fully specified node: shapes on all three abstracts.
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      abstract3->set_shape(shape3);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      break;
    }
    case 1: {
      // Don't set abstract of param1, expecting a exception raised.
      param2->set_abstract(abstract2);
      break;
    }
    case 2: {
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      // Output abstract is a plain scalar value - no tensor shape available.
      abstract3 = abstract::FromValue(static_cast<int64_t>(1), false);
      break;
    }
    case 3: {
      // Output shape is a tuple of the two input shapes.
      std::vector<BaseShapePtr> shape_o = {std::make_shared<abstract::Shape>(x), std::make_shared<abstract::Shape>(y)};
      BaseShapePtr shape4 = std::make_shared<abstract::TupleShape>(shape_o);
      abstract1->set_shape(shape1);
      abstract2->set_shape(shape2);
      abstract3->set_shape(shape4);
      param1->set_abstract(abstract1);
      param2->set_abstract(abstract2);
      break;
    }
    default:
      MS_LOG(INFO) << "Do Nothing!";
  }
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(param1);
  inputs.push_back(param2);
  CNodePtr node = func_graph->NewCNode(inputs);
  node->set_abstract(abstract3);
  return node;
}
// Builds a two-layer MatMul graph and wraps it in an initialized
// FuncGraphManager:
//   node1 = MatMul(param1 [64,32], param2 [32,64])   -> [64,64]
//   node2 = MatMul(node1,          param3 [64,128])  -> [64,128]
//   Return(node2)
// Both matmuls carry transpose_a/transpose_b = false, an instance_name and an
// "in_strategy" tuple attribute. `condition` perturbs matmul1's in_strategy so
// tests can hit the ExtractInformation error paths:
//   0 (default) - valid strategies {2,2},{2,4} on both matmuls.
//   1 - in_strategy replaced by a scalar 0 (not a tuple).
//   2 - in_strategy replaced by a tuple holding a scalar, not dimension lists.
//   3 - in_strategy {2,4},{2,4}, inconsistent with the [64,32]x[32,64] shapes.
FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
  std::vector<int64_t> inputs_x = {64, 32};
  std::vector<int64_t> inputs_y = {32, 64};
  std::vector<int64_t> inputs_z = {64, 128};
  std::vector<int64_t> outputs_1 = {64, 64};
  std::vector<int64_t> outputs_2 = {64, 128};
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr param1 = func_graph->add_parameter();
  ParameterPtr param2 = func_graph->add_parameter();
  ParameterPtr param3 = func_graph->add_parameter();
  // Tensor payloads only exist so FromValue can attach shaped abstracts.
  std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_x);
  std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_y);
  std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_z);
  std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_1);
  std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_2);
  AbstractBasePtr abstract_x = abstract::FromValue(inputs_x_dim, true);
  AbstractBasePtr abstract_y = abstract::FromValue(inputs_y_dim, true);
  AbstractBasePtr abstract_z = abstract::FromValue(inputs_z_dim, true);
  AbstractBasePtr abstract_out1 = abstract::FromValue(inputs_out1_dim, true);
  AbstractBasePtr abstract_out2 = abstract::FromValue(inputs_out2_dim, true);
  param1->set_abstract(abstract_x);
  param2->set_abstract(abstract_y);
  param3->set_abstract(abstract_z);
  // Default (valid) strategy for matmul1: {2,2} for lhs, {2,4} for rhs.
  Dimensions v1 = {2, 2};
  Dimensions v2 = {2, 4};
  std::vector<ValuePtr> elements = {MakeValue(v1), MakeValue(v2)};
  ValueTuplePtr var = std::make_shared<ValueTuple>(elements);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(param1);
  inputs.push_back(param2);
  CNodePtr node1 = func_graph->NewCNode(inputs);
  node1->set_in_forward_flag(true);
  node1->set_abstract(abstract_out1);
  PrimitivePtr prim1 = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim1->AddAttr("transpose_a", transpose_a);
  prim1->AddAttr("transpose_b", transpose_b);
  prim1->AddAttr("instance_name", MakeValue("matmul1"));
  prim1->AddAttr("in_strategy", var);
  inputs.clear();
  // Second matmul: node1 [64,64] x param3 [64,128].
  Dimensions v3 = {2, 2};
  Dimensions v4 = {2, 4};
  std::vector<ValuePtr> elements2 = {MakeValue(v3), MakeValue(v4)};
  ValueTuplePtr var2 = std::make_shared<ValueTuple>(elements2);
  inputs.push_back(NewValueNode(prim::kPrimMatMul));
  inputs.push_back(node1);
  inputs.push_back(param3);
  CNodePtr node2 = func_graph->NewCNode(inputs);
  node2->set_in_forward_flag(true);
  node2->set_abstract(abstract_out2);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(node2);
  CNodePtr cnode_return = func_graph->NewCNode(inputs);
  cnode_return->set_in_forward_flag(true);
  func_graph->set_return(cnode_return);
  PrimitivePtr prim2 = node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  prim2->AddAttr("transpose_a", transpose_a);
  prim2->AddAttr("transpose_b", transpose_b);
  prim2->AddAttr("instance_name", MakeValue("matmul2"));
  prim2->AddAttr("in_strategy", var2);
  // Overwrite matmul1's strategy with a broken variant when requested.
  switch (condition) {
    case 1: {
      prim1->set_attr("in_strategy", MakeValue(static_cast<int64_t>(0)));
      break;
    }
    case 2: {
      std::vector<ValuePtr> elements_t = {MakeValue(static_cast<int64_t>(0))};
      ValueTuplePtr var_t = std::make_shared<ValueTuple>(elements_t);
      prim1->set_attr("in_strategy", var_t);
      break;
    }
    case 3: {
      Dimensions vt1 = {2, 4};
      Dimensions vt2 = {2, 4};
      std::vector<ValuePtr> elements_t2 = {MakeValue(vt1), MakeValue(vt2)};
      ValueTuplePtr var_t2 = std::make_shared<ValueTuple>(elements_t2);
      prim1->set_attr("in_strategy", var_t2);
      break;
    }
  }
  std::vector<FuncGraphPtr> func_graphs{func_graph};
  FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(func_graphs, true);
  manager->Init();
  return manager;
}
  196. TEST_F(TestStepParallel, GetPythonPath1) {
  197. OperatorName operator_name = "AllReduce";
  198. const std::string expect = "mindspore.ops.operations";
  199. auto temp = parallel::GetOpPythonPath(operator_name);
  200. ASSERT_EQ(temp, expect);
  201. }
  202. TEST_F(TestStepParallel, GetPythonPath2) {
  203. OperatorName operator_name = "Add";
  204. const std::string expect = "mindspore.ops.operations";
  205. auto temp = parallel::GetOpPythonPath(operator_name);
  206. ASSERT_EQ(temp, expect);
  207. }
  208. TEST_F(TestStepParallel, ExtractStrategy) {
  209. Dimensions v1 = {2, 2};
  210. Dimensions v2 = {4, 4};
  211. mindspore::HashMap<std::string, ValuePtr> attrs;
  212. // stage
  213. ValuePtr val1 = MakeValue(v1);
  214. ValuePtr val2 = MakeValue(v2);
  215. std::vector<ValuePtr> elements = {val1, val2};
  216. ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements);
  217. attrs["in_strategy"] = strategy_tuple;
  218. Strategys strategy_expect = {v1, v2};
  219. StrategyPtr strategy = ExtractStrategy(attrs["in_strategy"]);
  220. Strategys strategy_test = strategy->GetInputDim();
  221. ASSERT_EQ(strategy_expect, strategy_test);
  222. }
  223. TEST_F(TestStepParallel, ExtractShape) {
  224. Shape inputs_x_dims = {64, 32};
  225. Shape inputs_y_dims = {32, 64};
  226. Shape outputs_dims = {64, 64};
  227. CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 4);
  228. EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
  229. }
  230. TEST_F(TestStepParallel, ExtractShape1) {
  231. Shape inputs_x_dims = {64, 32};
  232. Shape inputs_y_dims = {32, 64};
  233. Shape outputs_dims = {64, 64};
  234. CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  235. std::vector<Shapes> shape_test = ExtractShape(node);
  236. Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  237. Shapes outputs_shape = std::vector<Shape>{outputs_dims};
  238. std::vector<Shapes> shape_expect = {inputs_shape, outputs_shape};
  239. ASSERT_EQ(shape_test, shape_expect);
  240. }
  241. TEST_F(TestStepParallel, ExtractShape2) {
  242. Shape inputs_x_dims = {64, 32};
  243. Shape inputs_y_dims = {32, 64};
  244. Shape outputs_dims = {64, 64};
  245. CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 1);
  246. EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
  247. }
  248. TEST_F(TestStepParallel, ExtractShape3) {
  249. Shape inputs_x_dims = {64, 32};
  250. Shape inputs_y_dims = {32, 64};
  251. Shape outputs_dims = {64, 64};
  252. CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 3);
  253. Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
  254. std::vector<Shapes> shape_expect = {inputs_shape, inputs_shape};
  255. std::vector<Shapes> shape_test = ExtractShape(node);
  256. ASSERT_EQ(shape_test, shape_expect);
  257. }
/// Feature: test CreateOpInstance in auto parallel.
/// Description: create an AllReduce primitive through the Python adapter and
///              check every attribute survives the C++ <-> Python round trip.
/// Expectation: success.
TEST_F(TestStepParallel, CreateOpInstance) {
  // Build AllReduce(op=sum, group="0-1-2") through the generic op factory.
  ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
  ValuePtr attr1_value = MakeValue("0-1-2");
  Attr attr0 = std::make_pair("op", attr0_value);
  Attr attr1 = std::make_pair("group", attr1_value);
  OperatorAttrs attrs = {attr0, attr1};
  OperatorName op_name = "AllReduce";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  auto op_instance = CreateOpInstance(args.first, op_name, "test");
  ASSERT_TRUE(op_instance);
  PrimitivePyPtr allreduce_ptr = dyn_cast<PrimitivePy>(op_instance);
  ASSERT_TRUE(allreduce_ptr);
  if (nullptr != allreduce_ptr) {
    MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
    // Re-create the same op directly through the Python helper so its
    // attribute dict can be compared against what CreateOpInstance produced.
    std::vector<py::object> arglist;
    (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),
                         [](Attr attr) { return ValueToPyData(attr.second); });
    py::object allreduce_pyobj = python_adapter::CallPyFn(
      "mindspore.parallel._utils", "_get_python_op", "AllReduce", "mindspore.ops.operations", "test", arglist);
    py::dict opAttr = py::getattr(allreduce_pyobj, "attrs");
    mindspore::HashMap<std::string, ValuePtr> attributes{};
    for (auto item : opAttr) {
      // Every attribute key must be a python string.
      if (!py::isinstance<py::str>(item.first)) {
        MS_LOG(EXCEPTION) << "type error in py dict convert";
      }
      std::string name = py::cast<std::string>(item.first);
      MS_LOG(INFO) << "Attr name: " << name;
      ValuePtr converted_ret;
      // Check each known attribute's converted value; any unknown
      // attribute name is treated as a test failure.
      if (name == "op") {
        parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
        ASSERT_EQ(converted_ret->ToString(), "sum");
      } else {
        if (name == "group") {
          parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
          ASSERT_EQ(converted_ret->ToString(), "0-1-2");
        } else if (name == "fusion") {
          parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
          ASSERT_EQ(converted_ret->ToString(), "0");
        } else if (name == "instance_name") {
          parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
          ASSERT_EQ(converted_ret->ToString(), "test");
        } else if (name == "index") {
          parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
          ASSERT_EQ(converted_ret->ToString(), "0");
        } else if (name == "no_eliminate") {
          parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
          ASSERT_EQ(converted_ret->ToString(), "true");
        } else {
          MS_LOG(EXCEPTION) << "Test failed";
        }
      }
      attributes.emplace(name, converted_ret);
    }
  }
}
  317. /// Feature: test CreateOpInstance in auto parallel.
  318. /// Description: net with MicroBatchInterleaved in semi auto parallel.
  319. /// Expectation: success.
  320. TEST_F(TestStepParallel, CreateOpInstance1) {
  321. OperatorAttrs attrs;
  322. OperatorName op_name = "ABC";
  323. OperatorParams operator_param;
  324. OperatorArgs args = std::make_pair(attrs, operator_param);
  325. EXPECT_THROW({ CreateOpInstance(args.first, op_name, "test"); }, std::runtime_error);
  326. }
  327. TEST_F(TestStepParallel, OperatorInstance) {
  328. // create attrs and prim
  329. PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>();
  330. ValuePtr transpose_a = MakeValue(false);
  331. ValuePtr transpose_b = MakeValue(false);
  332. prim->set_attr("transpose_a", transpose_a);
  333. prim->set_attr("transpose_b", transpose_b);
  334. auto attrs = prim->attrs();
  335. // create strategy
  336. Strategys strategy = {{2, 2}, {2, 4}};
  337. StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
  338. // create shape
  339. Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  340. Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  341. std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  342. TOTAL_OPS = 0;
  343. OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  344. matmul_info->Init(strategyPtr, nullptr);
  345. std::string name_expect = "MatMulInfo00";
  346. std::string name_test = matmul_info->name();
  347. ASSERT_EQ(name_expect, name_test);
  348. }
  349. TEST_F(TestStepParallel, ExtractInformation) {
  350. FuncGraphManagerPtr manager = Make_Manager();
  351. FuncGraphSet graphs = manager->func_graphs();
  352. FuncGraphPtr graph = *graphs.begin();
  353. auto ret = graph->get_return();
  354. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  355. ExtractInformation(all_nodes);
  356. }
  357. TEST_F(TestStepParallel, ExtractInformation2) {
  358. FuncGraphManagerPtr manager = Make_Manager(2);
  359. FuncGraphSet graphs = manager->func_graphs();
  360. FuncGraphPtr graph = *graphs.begin();
  361. auto ret = graph->get_return();
  362. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  363. EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error);
  364. }
  365. TEST_F(TestStepParallel, ExtractInformation3) {
  366. FuncGraphManagerPtr manager = Make_Manager(3);
  367. FuncGraphSet graphs = manager->func_graphs();
  368. FuncGraphPtr graph = *graphs.begin();
  369. auto ret = graph->get_return();
  370. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  371. EXPECT_THROW({ ExtractInformation(all_nodes); }, std::runtime_error);
  372. }
// Inserts two AllReduce ops after every MatMul in the test graph, then walks
// the rewritten graph and checks that whatever now feeds Return/MatMul is a
// chain of two AllReduce nodes.
TEST_F(TestStepParallel, ForwardCommunication1) {
  // Build an AllReduce(sum, group "0-1-2") operator and duplicate it,
  // so ForwardCommunication inserts a chain of two.
  ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
  ValuePtr attr1_value = MakeValue("0-1-2");
  Attr attr0 = std::make_pair("op", attr0_value);
  Attr attr1 = std::make_pair("group", attr1_value);
  OperatorAttrs attrs = {attr0, attr1};
  OperatorName op_name = "AllReduce";
  OperatorParams operator_param;
  OperatorArgs args = std::make_pair(attrs, operator_param);
  Operator op = std::make_pair(op_name, args);
  OperatorVector op_list = {op, op};
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  // Phase 1: rewrite - attach the AllReduce chain after each MatMul.
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    PrimitivePtr prim = cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    if (prim->name() == "MatMul") {
      ForwardCommunication(op_list, cnode);
    }
  }
  // Phase 2: verify - every non-parameter input of Return/MatMul must now be
  // an AllReduce whose own input is the second AllReduce of the chain.
  AnfNodeSet after_nodes = manager->all_nodes();
  for (auto &node : after_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto &inputs = node->cast<CNodePtr>()->inputs();
    PrimitivePtr prim = inputs[0]->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
    if (prim->name() == "Return" || prim->name() == "MatMul") {
      if (!inputs[1]->isa<Parameter>()) {
        CNodePtr pre_node = inputs[1]->cast<CNodePtr>();
        PrimitivePtr pre_prim = pre_node->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
        CNodePtr pre_node2 = pre_node->input(1)->cast<CNodePtr>();
        PrimitivePtr pre_prim2 = pre_node2->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
        ASSERT_EQ("AllReduce", pre_prim->name());
        ASSERT_EQ("AllReduce", pre_prim2->name());
      }
    }
  }
}
// Detaching the graph's manager before inserting communication ops must
// make ForwardCommunication throw, since the rewrite needs the manager.
TEST_F(TestStepParallel, ForwardCommunication2) {
  OperatorVector op_list;
  FuncGraphManagerPtr manager = Make_Manager();
  FuncGraphSet graphs = manager->func_graphs();
  FuncGraphPtr graph = *graphs.begin();
  auto ret = graph->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  ExtractInformation(all_nodes);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    FuncGraphPtr func_graph = node->func_graph();
    // Deliberately break the precondition: the graph loses its manager.
    func_graph->set_manager(nullptr);
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->name() == "MatMul") {
      EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
      break;
    }
  }
}
  442. TEST_F(TestStepParallel, ForwardCommunication3) {
  443. OperatorVector op_list;
  444. FuncGraphManagerPtr manager = Make_Manager();
  445. FuncGraphSet graphs = manager->func_graphs();
  446. FuncGraphPtr graph = *graphs.begin();
  447. auto ret = graph->get_return();
  448. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  449. ExtractInformation(all_nodes);
  450. for (auto &node : all_nodes) {
  451. if (!node->isa<CNode>()) {
  452. continue;
  453. }
  454. auto cnode = node->cast<CNodePtr>();
  455. FuncGraphPtr func_graph = node->func_graph();
  456. PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  457. if (prim->name() == "MatMul") {
  458. OperatorAttrs attrs;
  459. OperatorParams operator_param;
  460. OperatorArgs args = std::make_pair(attrs, operator_param);
  461. Operator op = std::make_pair("ABC", args);
  462. OperatorVector op_list = {op};
  463. EXPECT_THROW({ ForwardCommunication(op_list, cnode); }, std::runtime_error);
  464. break;
  465. }
  466. }
  467. }
// Attaches a MatMulInfo to the node feeding `node1` and checks that
// GetTensorInLayout reports the producer's output tensor shape [64,64].
TEST_F(TestStepParallel, GetTensorInLayout) {
  // create attrs and prim
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  // `node` acts as the producer; `node1` reuses its inputs as the consumer.
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims);
  std::vector<AnfNodePtr> inputs(node->inputs());
  CNodePtr node1 = func_graph->NewCNode(inputs);
  PrimitivePtr prim = node1->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
  ValuePtr transpose_a = MakeValue(false);
  ValuePtr transpose_b = MakeValue(false);
  prim->set_attr("transpose_a", transpose_a);
  prim->set_attr("transpose_b", transpose_b);
  auto attrs = prim->attrs();
  // create strategy
  Strategys strategy = {{2, 2}, {2, 4}};
  StrategyPtr strategyPtr = parallel::NewStrategy(0, strategy);
  // create shape
  Shapes inputs_shape = std::vector<Shape>{{64, 32}, {32, 64}};
  Shapes outputs_shape = std::vector<Shape>{{64, 64}};
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  // Build and initialize the MatMul operator info, then hang it on the
  // producer node as user data so GetTensorInLayout can find it.
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  matmul_info->Init(strategyPtr, nullptr);
  node->set_user_data<OperatorInfo>(matmul_info);
  OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>();
  TensorLayout tensorlayout_e;
  Shape array = {64, 64};
  TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);
  Shape tensor_shape_test = tensorlayout.tensor_shape().array();
  ASSERT_EQ(array, tensor_shape_test);
}
  500. } // namespace parallel
  501. } // namespace mindspore