You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ascend_control_parser.cc 22 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <utility>
  17. #include <memory>
  18. #include "session/ascend_control_parser.h"
  19. #include "session/anf_runtime_algorithm.h"
  20. static constexpr size_t kCNodePrim = 0;
  21. static constexpr size_t kCNodeCallArg = 1;
  22. static constexpr size_t kCNodeSwitchCond = 1;
  23. static constexpr size_t kCNodeSwitchTrue = 2;
  24. static constexpr size_t kCNodeSwitchFalse = 3;
  25. static constexpr size_t kCNodeSwitchLength = 4;
  26. static constexpr size_t kCNodePartialLength = 2;
  27. static constexpr size_t kCNodePartialFunc = 1;
  28. static constexpr size_t kCNodeSwitchLayerBranch = 2;
  29. static constexpr size_t kCNodeSwitchLayerLength = 3;
  30. namespace mindspore {
  31. namespace session {
  32. void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
  33. for (auto &iter : graph_id_map) {
  34. auto &kg = iter.second;
  35. MS_EXCEPTION_IF_NULL(kg);
  36. auto real_inputs = kg->real_inputs();
  37. for (auto &it : real_inputs) {
  38. auto &parameter = it.first;
  39. auto &args = it.second;
  40. for (auto &arg : args) {
  41. MS_EXCEPTION_IF_NULL(arg);
  42. if (arg->isa<Parameter>()) {
  43. MS_LOG(INFO) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
  44. << ", arg:" << arg->DebugString();
  45. continue;
  46. }
  47. auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
  48. if (target_graph_iter == graph_id_map.end()) {
  49. MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
  50. }
  51. InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
  52. }
  53. }
  54. }
  55. }
  56. void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
  57. std::set<KernelGraphPtr> memo;
  58. ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
  59. std::map<uint32_t, KernelGraphPtr> graph_id_map;
  60. for (auto &g : memo) {
  61. if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
  62. MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id()
  63. << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString();
  64. }
  65. graph_id_map[g->graph_id()] = g;
  66. }
  67. ChildGraphDataAssign(graph_id_map);
  68. }
  69. CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
  70. for (size_t i = start; i < list.size() - 1; ++i) {
  71. if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
  72. return list[i];
  73. }
  74. }
  75. return nullptr;
  76. }
// Links `kg` into the control flow and recursively processes the graphs it calls.
//
// Steps:
//  1. If `kg` is already in `memo`, return its existing start label.
//  2. LinkParentGraph: entry graph -> return wrapped into make_tuple; called graph -> end goto to `last_label`.
//  3. Rebuild the default execution order.
//  4. Create a LabelSet node and record it as the graph's start label.
//  5. For every Call node, dispatch on its callee. A KernelGraph value node goes to RecurseCall, a Switch CNode
//     to RecurseSwitch, and a SwitchLayer CNode to RecurseSwitchLayer.
// Returns the start label of `kg`. Throws if a node has no inputs, the graph has no cnodes, or a call's callee
// is neither a KernelGraph value node nor a CNode.
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
                                                          const CNodePtr &last_label,
                                                          const NotNull<std::set<KernelGraphPtr> *> memo) {
  MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();
  // 1. recursive condition
  if (memo->find(kg) != memo->end()) {
    MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString();
    return NOT_NULL(kg->get_start_label());
  }
  // Mark as visited before recursing so cycles through this graph terminate at step 1.
  memo->insert(kg.get());
  // 2. args replace placeholder
  LinkParentGraph(kg, last_node, last_label);
  // 3. topological sort
  kg->SetExecOrderByDefault();
  const std::vector<CNodePtr> &nodes = kg->execution_order();
  if (nodes.empty()) {
    MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
  }
  // 4. insert first_label
  auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
  kg->set_start_label(start_label);
  // 5. traverse
  for (size_t i = 0; i < nodes.size(); ++i) {
    auto &cnode = nodes[i];
    if (cnode->size() < kCNodePrim + 1) {
      MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
    }
    AnfNodePtr fn = cnode->input(kCNodePrim);
    // Only Call nodes that carry a callee input are rewritten; everything else is skipped.
    if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
      MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
      continue;
    }
    AnfNodePtr arg = cnode->input(kCNodeCallArg);
    // GetNextRealKernel(nodes, i + 1) is the node the Recurse* helpers order after the call's back label.
    if (IsValueNode<KernelGraph>(arg)) {
      RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
    } else if (!arg->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
    } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
      auto arg_cnode = arg->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(arg_cnode);
      // Flatten call(switch(...)) so the call node itself carries the switch inputs.
      cnode->set_inputs(arg_cnode->inputs());
      RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
    } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
      auto arg_cnode = arg->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(arg_cnode);
      // Flatten call(switch_layer(...)) the same way.
      cnode->set_inputs(arg_cnode->inputs());
      RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
    }
    // NOTE(review): a Call whose callee is any other kind of CNode is left unchanged without a log.
  }
  MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
  return NOT_NULL(start_label);
}
  130. void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
  131. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
  132. auto return_node = kg->get_return();
  133. MS_EXCEPTION_IF_NULL(return_node);
  134. inputs.push_back(return_node->input(1));
  135. inputs.push_back(attch_node.get());
  136. auto depend_node = kg->NewCNode(inputs);
  137. return_node->set_input(1, depend_node);
  138. }
  139. void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
  140. NotNull<AnfNodePtr> second_node) {
  141. MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
  142. << ", the second node is " << second_node->DebugString();
  143. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
  144. first_node, second_node};
  145. auto control_depend = kg->NewCNode(inputs);
  146. InsertDependToGraph(kg, NOT_NULL(control_depend));
  147. }
  148. void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
  149. const CNodePtr &last_label) {
  150. auto origin_return = kg->get_return();
  151. const std::vector<AnfNodePtr> &origin_return_inputs = origin_return->inputs();
  152. // if entry graph, replace return with make_tuple
  153. if (from_graph_call_node == nullptr || last_label == nullptr) {
  154. MS_LOG(INFO) << kg->ToString() << " is entry graph.";
  155. std::vector<AnfNodePtr> make_tuple_inputs = {std::make_shared<ValueNode>(prim::kPrimMakeTuple)};
  156. make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end());
  157. auto make_tuple = kg->NewCNode(make_tuple_inputs);
  158. origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple});
  159. } else {
  160. // else replace return with label_goto
  161. auto label_goto =
  162. kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
  163. MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString();
  164. kg->set_end_goto(label_goto);
  165. }
  166. }
  167. void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
  168. const NotNull<std::set<KernelGraphPtr> *> memo) {
  169. MS_LOG(INFO) << "process call func " << cur_node->DebugString();
  170. // 1 get kernel graph
  171. const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
  172. std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
  173. if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
  174. MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
  175. return;
  176. }
  177. // 2 return label
  178. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  179. MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node "
  180. << cur_node->DebugString();
  181. // 3 add depend relationship
  182. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  183. if (next_node != nullptr && next_node != kg->get_return()) {
  184. InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
  185. }
  186. auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]);
  187. // 4 modify call op to goto op
  188. cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]);
  189. // 5 recurse sub graph
  190. CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo);
  191. new_inputs.push_back(sub_label);
  192. new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
  193. cur_node->set_inputs(new_inputs);
  194. cur_node->set_abstract(nullptr);
  195. MS_LOG(INFO) << "success process call func " << cur_node->DebugString();
  196. }
  197. void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
  198. const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
  199. MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
  200. if (cur_node->size() < kCNodeSwitchLength) {
  201. MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
  202. }
  203. // 1 return label
  204. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  205. MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node "
  206. << cur_node->DebugString();
  207. // 2 add depend relationship
  208. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  209. if (next_node != nullptr && next_node != kg->get_return()) {
  210. InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
  211. }
  212. // 3 recurse sub graph
  213. const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
  214. std::vector<AnfNodePtr> new_switch_inputs = {
  215. std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
  216. origin_switch_inputs[kCNodeSwitchCond]};
  217. for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
  218. // 3.1 branch kernel graph and args
  219. KernelGraphPtr branch_fg;
  220. std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
  221. // 3.2 recurse sub graph
  222. CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
  223. new_switch_inputs.push_back(branch_label);
  224. }
  225. std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
  226. new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
  227. cur_node->set_inputs(new_switch_inputs);
  228. cur_node->set_abstract(nullptr);
  229. MS_LOG(INFO) << "success process switch func " << cur_node->DebugString();
  230. }
  231. void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
  232. const CNodePtr &next_node,
  233. const NotNull<std::set<KernelGraphPtr> *> memo) {
  234. MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
  235. if (cur_node->size() < kCNodeSwitchLayerLength) {
  236. MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
  237. }
  238. auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
  239. MS_EXCEPTION_IF_NULL(branch_tuple);
  240. if (!branch_tuple->isa<CNode>()) {
  241. MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode";
  242. }
  243. const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
  244. // 1 return label
  245. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  246. // 2 add depend relationship
  247. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  248. if (next_node != nullptr && next_node != kg->get_return()) {
  249. InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
  250. }
  251. // 3 recurse sub graph
  252. const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
  253. std::vector<AnfNodePtr> new_switch_inputs = {
  254. std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
  255. origin_switch_inputs[kCNodeSwitchCond]};
  256. for (size_t i = 0; i < branch_partial.size(); ++i) {
  257. // 3.1 branch kernel graph and args
  258. KernelGraphPtr branch_fg;
  259. std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
  260. // 3.2 recurse sub graph
  261. CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
  262. new_switch_inputs.push_back(branch_label);
  263. }
  264. new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
  265. cur_node->set_inputs(new_switch_inputs);
  266. cur_node->set_abstract(nullptr);
  267. MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString();
  268. }
  269. std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
  270. if (!node.get()->isa<CNode>()) {
  271. MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
  272. }
  273. // 2.1 branch kernel graph and args
  274. auto partial_cnode = utils::cast<CNodePtr>(node.get());
  275. if (partial_cnode->size() < kCNodePartialLength) {
  276. MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
  277. }
  278. auto partial_inputs = partial_cnode->inputs();
  279. auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
  280. return {partial_cnode, branch_kg};
  281. }
  282. void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
  283. NotNull<AnfNodePtr> to) {
  284. if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
  285. AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
  286. return;
  287. }
  288. if (from.get() == to.get()) {
  289. return;
  290. }
  291. MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
  292. << to->DebugString();
  293. // config inputs of assign node
  294. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
  295. // generate a new cnode
  296. auto assign_node = kg->NewCNode(inputs);
  297. MS_EXCEPTION_IF_NULL(assign_node);
  298. assign_node->set_abstract(to->abstract());
  299. // append the assign at the end of from graph
  300. InsertDependToGraph(kg, NOT_NULL(assign_node));
  301. }
  302. void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
  303. NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) {
  304. if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) {
  305. MS_LOG(INFO) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " is a tuple";
  306. CNodePtr cnode_arg = arg.get()->cast<CNodePtr>();
  307. CNodePtr cnode_param = param.get()->cast<CNodePtr>();
  308. MS_EXCEPTION_IF_NULL(cnode_arg);
  309. MS_EXCEPTION_IF_NULL(cnode_param);
  310. if (cnode_arg->size() != cnode_param->size()) {
  311. MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " size " << cnode_arg->size() << " but Param "
  312. << param->DebugString() << " size " << cnode_param->size();
  313. }
  314. for (size_t i = 1; i < cnode_param->size(); ++i) {
  315. LinkArgsToParam(to_graph, target_graph, NOT_NULL(cnode_arg->input(i)), NOT_NULL(cnode_param->input(i)));
  316. }
  317. } else if (arg->isa<CNode>()) {
  318. InsertAssignToGraph(target_graph, arg, param);
  319. } else {
  320. MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " unknown type.";
  321. }
  322. }
  323. void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
  324. std::set<KernelGraphPtr> memo;
  325. (void)RecurseGraph(root_graph, NOT_NULL(&memo));
  326. }
// Rebuilds the execution order of `graph` by splicing in the execution orders of its child graphs.
//
// Each graph is visited at most once (tracked in `memo`); a revisit returns an empty order. The graph's own
// nodes are copied in their default order, and child orders are spliced in as follows:
//  - After a LabelGoto (other than the graph's end goto), the next child_graph_order() entry is checked against
//    the goto's label index. Its recursively built order is appended unless it is the parent graph.
//  - After a LabelSwitch, the same is done for each entry of the switch's label list, walked in reverse, and
//    each label consumes one child_graph_order() entry.
// Skipped parent-graph entries still consume an index. The merged order is stored on `graph` and returned.
// Throws if a label index check fails.
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
                                                        const NotNull<std::set<KernelGraphPtr> *> memo) {
  MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
  if (memo->find(graph) != memo->end()) {
    return {};
  }
  memo->insert(graph.get());
  graph->SetExecOrderByDefault();
  const std::vector<CNodePtr> &cnodes = graph->execution_order();
  std::vector<CNodePtr> execution_order;
  // Position of the next entry to consume from graph->child_graph_order().
  uint32_t child_order_index = 0;
  for (auto &node : cnodes) {
    execution_order.push_back(node);
    // The end goto jumps back to the caller's label (see LinkParentGraph); it does not enter a child graph.
    if (node == graph->get_end_goto()) {
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
      // label_index 0 is a placeholder: CheckLabelIndex reads the real index from the LabelGoto's attribute.
      if (!CheckLabelIndex(child_order_index, 0, node, graph)) {
        MS_LOG(EXCEPTION) << "Check label index fail";
      }
      auto child_graph = graph->child_graph_order()[child_order_index++];
      if (child_graph == graph->parent_graph()) {
        continue;
      }
      auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
      execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
    } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
      std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node);
      // Walked in reverse; each label is matched against the next child_graph_order() entry.
      for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
        if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
          MS_LOG(EXCEPTION) << "Check label index fail";
        }
        auto child_graph = graph->child_graph_order()[child_order_index++];
        if (child_graph == graph->parent_graph()) {
          continue;
        }
        auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
        execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
      }
    }
  }
  graph->set_execution_order(execution_order);
  graph->PrintGraphExecuteOrder();
  return execution_order;
}
  372. std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) {
  373. if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
  374. MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
  375. }
  376. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  377. MS_EXCEPTION_IF_NULL(primitive);
  378. return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
  379. }
  380. bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
  381. NotNull<KernelGraphPtr> graph) {
  382. const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
  383. // check index and child order size
  384. if (child_graph_order.size() <= IntToSize(order_index)) {
  385. MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
  386. << child_graph_order.size() << " goto index " << order_index;
  387. }
  388. if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) {
  389. // check label_goto and start_label in child graph
  390. if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) {
  391. MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
  392. }
  393. auto primitive = AnfAlgo::GetCNodePrimitive(cur_label);
  394. MS_EXCEPTION_IF_NULL(primitive);
  395. uint32_t label_goto_index = GetValue<uint32_t>(primitive->GetAttr(kAttrLabelIndex));
  396. label_index = label_goto_index;
  397. }
  398. // get start_label_set_index of child graph
  399. auto child_graph = child_graph_order[order_index];
  400. MS_EXCEPTION_IF_NULL(child_graph);
  401. auto start_label_set = child_graph->get_start_label();
  402. if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) {
  403. MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
  404. }
  405. auto start_primitive = AnfAlgo::GetCNodePrimitive(start_label_set);
  406. MS_EXCEPTION_IF_NULL(start_primitive);
  407. uint32_t start_label_set_index = GetValue<uint32_t>(start_primitive->GetAttr(kAttrLabelIndex));
  408. if (label_index != start_label_set_index) {
  409. MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
  410. << " index " << start_label_set_index << " current child graph order : " << order_index;
  411. return false;
  412. }
  413. return true;
  414. }
  415. void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
  416. MS_LOG(INFO) << "graph id:" << kg->graph_id();
  417. kg->SetExecOrderByDefault();
  418. auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
  419. std::vector<KernelGraphPtr> child_graph_order;
  420. for (auto &call_node : call_nodes) {
  421. MS_EXCEPTION_IF_NULL(call_node);
  422. auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
  423. for (const auto &child_graph : call_child_graphs) {
  424. MS_EXCEPTION_IF_NULL(child_graph);
  425. if (child_graph != kg->parent_graph()) {
  426. child_graph->set_parent_graph(kg.get());
  427. }
  428. child_graph_order.push_back(child_graph);
  429. }
  430. }
  431. for (size_t i = 0; i < child_graph_order.size(); i++) {
  432. MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
  433. }
  434. kg->set_child_graph_order(child_graph_order);
  435. }
  436. } // namespace session
  437. } // namespace mindspore