You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ascend_control_parser.cc 29 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/ascend_control_parser.h"
  17. #include <utility>
  18. #include <memory>
  19. #include "session/anf_runtime_algorithm.h"
  20. #include "utils/union_find_set.h"
  21. #include "device/ascend/ascend_label_assign.h"
  // Input-slot indices and minimum input counts for the CNodes rewritten by this parser.
  // Slot 0 of every CNode holds its primitive.
  static constexpr size_t kCNodePrim = 0;
  // Call(graph, ...): slot 1 holds the called graph; LabelGoto keeps its target label in the same slot.
  static constexpr size_t kCNodeCallArg = 1;
  // Switch(cond, branch_a, branch_b): condition in slot 1, branches in slots 2 and 3.
  static constexpr size_t kCNodeSwitchCond = 1;
  static constexpr size_t kCNodeSwitchTrue = 2;
  static constexpr size_t kCNodeSwitchFalse = 3;
  static constexpr size_t kCNodeSwitchLength = 4;
  // Partial(graph, args...): the graph sits in slot 1; at least two inputs are required.
  static constexpr size_t kCNodePartialLength = 2;
  static constexpr size_t kCNodePartialFunc = 1;
  // SwitchLayer(index, branch_tuple): the branch tuple sits in slot 2; three inputs are required.
  static constexpr size_t kCNodeSwitchLayerBranch = 2;
  static constexpr size_t kCNodeSwitchLayerLength = 3;
  32. namespace mindspore {
  33. namespace session {
  34. static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
  35. auto &nodes = parent_graph->execution_order();
  36. CNodePtr last_jump_node = nullptr;
  37. for (auto &node : nodes) {
  38. if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
  39. if (child_graph->get_start_label() == node->input(kCNodeCallArg)) {
  40. return node;
  41. }
  42. last_jump_node = node;
  43. } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
  44. if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
  45. child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) {
  46. return node;
  47. }
  48. last_jump_node = node;
  49. }
  50. }
  51. if (last_jump_node == nullptr) {
  52. MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
  53. }
  54. return last_jump_node;
  55. }
  56. static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
  57. const NotNull<std::set<KernelGraphPtr> *> memo) {
  58. if (memo->find(kg.get()) != memo->end()) {
  59. return;
  60. }
  61. memo->insert(kg.get());
  62. const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
  63. for (auto &iter : real_inputs) {
  64. auto &para = iter.first;
  65. MS_EXCEPTION_IF_NULL(para);
  66. if (para->isa<Parameter>()) {
  67. union_find_set->Add(para);
  68. }
  69. for (auto &arg : iter.second) {
  70. MS_EXCEPTION_IF_NULL(arg);
  71. if (!arg->isa<Parameter>()) {
  72. continue;
  73. }
  74. union_find_set->Add(arg);
  75. }
  76. }
  77. for (auto &child : kg->child_graph_order()) {
  78. InitUnionFindSet(NOT_NULL(child), union_find_set, memo);
  79. }
  80. }
  81. static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
  82. const NotNull<std::set<KernelGraphPtr> *> memo) {
  83. if (memo->find(kg.get()) != memo->end()) {
  84. return;
  85. }
  86. memo->insert(kg.get());
  87. const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
  88. for (auto &iter : real_inputs) {
  89. auto &para = iter.first;
  90. for (auto &arg : iter.second) {
  91. MS_EXCEPTION_IF_NULL(arg);
  92. if (!arg->isa<Parameter>()) {
  93. continue;
  94. }
  95. if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) {
  96. continue;
  97. }
  98. union_find_set->Union(arg, para);
  99. }
  100. }
  101. for (auto &child : kg->child_graph_order()) {
  102. UnionParentParameter(NOT_NULL(child), union_find_set, memo);
  103. }
  104. }
  105. static UnionFindSet<AnfNodePtr> MakeUnionFindSet(NotNull<KernelGraphPtr> root_kg) {
  106. UnionFindSet<AnfNodePtr> result;
  107. std::set<KernelGraphPtr> memo;
  108. InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo));
  109. memo.clear();
  110. UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo));
  111. return result;
  112. }
  113. static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter,
  114. const std::set<AnfNodePtr> &parameter_reuse_set,
  115. const NotNull<std::set<KernelGraphPtr> *> memo) {
  116. if (parameter_reuse_set.empty()) {
  117. MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty.";
  118. }
  119. if (memo->find(kg.get()) != memo->end()) {
  120. return;
  121. }
  122. memo->insert(kg.get());
  123. for (auto &para : parameter_reuse_set) {
  124. if (para == main_parameter.get()) {
  125. continue;
  126. }
  127. MS_EXCEPTION_IF_NULL(para);
  128. MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
  129. << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
  130. kg->ReplaceNode(NOT_NULL(para), main_parameter);
  131. }
  132. for (auto &child : kg->child_graph_order()) {
  133. RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo);
  134. }
  135. }
  136. static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
  137. const std::set<AnfNodePtr> &parameter_reuse_set) {
  138. AnfNodePtr main_parameter = key;
  139. std::set<AnfNodePtr> root_inputs_set;
  140. const auto &root_inputs_vector = root_kg->inputs();
  141. root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
  142. for (auto &node : parameter_reuse_set) {
  143. if (root_inputs_set.find(node) != root_inputs_set.end()) {
  144. main_parameter = node;
  145. break;
  146. }
  147. }
  148. return main_parameter;
  149. }
  150. static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
  151. auto parameter_reuse_sets = parameter_set->GetSets();
  152. for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
  153. if (parameter_reuse_set.size() <= 1) {
  154. continue;
  155. }
  156. auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set);
  157. std::set<KernelGraphPtr> memo;
  158. RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo));
  159. }
  160. }
  161. CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
  162. for (size_t i = start; i < list.size() - 1; ++i) {
  163. if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
  164. return list[i];
  165. }
  166. }
  167. return nullptr;
  168. }
  169. void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
  170. std::set<KernelGraphPtr> memo;
  171. (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
  172. device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
  173. std::map<uint32_t, KernelGraphPtr> graph_id_map;
  174. for (auto &g : memo) {
  175. MS_EXCEPTION_IF_NULL(g);
  176. if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
  177. MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id()
  178. << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString();
  179. }
  180. graph_id_map[g->graph_id()] = g;
  181. }
  182. // Insert Assign
  183. ChildGraphDataAssign(graph_id_map);
  184. // Make UnionFindSet
  185. UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
  186. // Reuse Parameter
  187. ReuseParameter(kg, NOT_NULL(&parameter_set));
  188. }
  189. void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
  190. std::set<KernelGraphPtr> memo;
  191. (void)RecurseGraph(root_graph, NOT_NULL(&memo));
  192. }
  193. void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
  194. for (auto &iter : graph_id_map) {
  195. auto &kg = iter.second;
  196. MS_LOG(INFO) << "Data assign graph:" << kg->graph_id();
  197. MS_EXCEPTION_IF_NULL(kg);
  198. std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
  199. const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
  200. for (auto &it : real_inputs) {
  201. auto &parameter = it.first;
  202. auto &args = it.second;
  203. for (auto &arg : args) {
  204. MS_EXCEPTION_IF_NULL(arg);
  205. if (memo.find({parameter, arg}) != memo.end()) {
  206. continue;
  207. } else {
  208. memo.emplace(parameter, arg);
  209. }
  210. auto unreuse_args_map = kg->unreuse_args();
  211. auto unreuse_arg_iter = unreuse_args_map.find(arg);
  212. if (unreuse_arg_iter == unreuse_args_map.end()) {
  213. MS_EXCEPTION_IF_NULL(arg);
  214. MS_EXCEPTION_IF_NULL(parameter);
  215. if (!arg->isa<Parameter>()) {
  216. MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << ".";
  217. }
  218. MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
  219. << ", arg:" << arg->DebugString();
  220. continue;
  221. }
  222. auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
  223. if (target_graph_iter == graph_id_map.end()) {
  224. MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
  225. }
  226. InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
  227. NOT_NULL(parameter));
  228. }
  229. }
  230. kg->SetExecOrderByDefault();
  231. }
  232. }
  233. NotNull<CNodePtr> AscendControlParser::GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
  234. const CNodePtr &last_label) {
  235. CNodePtr start_label;
  236. if (last_node != nullptr && last_label != nullptr) {
  237. start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  238. MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
  239. kg->set_start_label(start_label);
  240. } else {
  241. // no goto node will jump to start label of root graph, so return a fake label
  242. start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
  243. }
  244. return NOT_NULL(start_label);
  245. }
  246. NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
  247. const CNodePtr &last_label,
  248. const NotNull<std::set<KernelGraphPtr> *> memo) {
  249. MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();
  250. // 1. recursive condition
  251. if (memo->find(kg) != memo->end()) {
  252. MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString();
  253. return NOT_NULL(kg->get_start_label());
  254. }
  255. memo->insert(kg.get());
  256. // 2. args replace placeholder
  257. LinkParentGraph(kg, last_node, last_label);
  258. // 3. topological sort
  259. kg->SetExecOrderByDefault();
  260. const std::vector<CNodePtr> &nodes = kg->execution_order();
  261. // 4. insert first_label
  262. CNodePtr start_label = GetStartLabel(kg, last_node, last_label);
  263. // 5. traverse
  264. for (size_t i = 0; i < nodes.size(); ++i) {
  265. auto &cnode = nodes[i];
  266. MS_EXCEPTION_IF_NULL(cnode);
  267. if (cnode->size() < kCNodePrim + 1) {
  268. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  269. }
  270. AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
  271. if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
  272. MS_LOG(DEBUG) << "Continue node " << cnode->DebugString();
  273. continue;
  274. }
  275. AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
  276. MS_EXCEPTION_IF_NULL(arg);
  277. if (IsValueNode<KernelGraph>(arg)) {
  278. RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
  279. } else if (!arg->isa<CNode>()) {
  280. MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
  281. } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
  282. auto arg_cnode = arg->cast<CNodePtr>();
  283. MS_EXCEPTION_IF_NULL(arg_cnode);
  284. cnode->set_inputs(arg_cnode->inputs());
  285. RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
  286. } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
  287. auto arg_cnode = arg->cast<CNodePtr>();
  288. MS_EXCEPTION_IF_NULL(arg_cnode);
  289. cnode->set_inputs(arg_cnode->inputs());
  290. RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
  291. }
  292. }
  293. kg->SetExecOrderByDefault();
  294. MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
  295. return NOT_NULL(start_label);
  296. }
  297. void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
  298. auto return_node = kg->get_return();
  299. MS_EXCEPTION_IF_NULL(return_node);
  300. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
  301. return_node->input(kFirstDataInputIndex), attch_node.get()};
  302. auto depend_node = kg->NewCNode(inputs);
  303. return_node->set_input(1, depend_node);
  304. }
  305. void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
  306. NotNull<AnfNodePtr> second_node) {
  307. MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
  308. << ", the second node is " << second_node->DebugString();
  309. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
  310. first_node, second_node};
  311. auto control_depend = kg->NewCNode(inputs);
  312. InsertDependToGraph(kg, NOT_NULL(control_depend));
  313. }
  314. void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
  315. const CNodePtr &last_label) {
  316. // if not entry graph, replace return with label_goto
  317. if (from_graph_call_node != nullptr && last_label != nullptr) {
  318. auto label_goto =
  319. kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
  320. MS_EXCEPTION_IF_NULL(label_goto);
  321. MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString();
  322. kg->set_end_goto(label_goto);
  323. }
  324. }
  325. void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
  326. const NotNull<std::set<KernelGraphPtr> *> memo) {
  327. MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
  328. // 1 get kernel graph
  329. const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
  330. if (kCNodeCallArg >= origin_inputs.size()) {
  331. MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
  332. }
  333. std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
  334. if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
  335. MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
  336. return;
  337. }
  338. // 2 return label
  339. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  340. MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node "
  341. << cur_node->DebugString();
  342. // 3 add depend relationship
  343. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  344. if (next_node != nullptr && next_node != kg->get_return()) {
  345. InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
  346. }
  347. auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]);
  348. // 4 modify call op to goto op
  349. cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]);
  350. // 5 recurse sub graph
  351. CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo);
  352. new_inputs.push_back(sub_label);
  353. cur_node->set_inputs(new_inputs);
  354. cur_node->set_abstract(nullptr);
  355. MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
  356. }
  357. void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
  358. const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
  359. MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
  360. if (cur_node->size() < kCNodeSwitchLength) {
  361. MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
  362. }
  363. // 1 return label
  364. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  365. MS_EXCEPTION_IF_NULL(back_label);
  366. MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node "
  367. << cur_node->DebugString();
  368. // 2 add depend relationship
  369. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  370. if (next_node != nullptr && next_node != kg->get_return()) {
  371. InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
  372. }
  373. // 3 recurse sub graph
  374. const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
  375. if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
  376. MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond;
  377. }
  378. std::vector<AnfNodePtr> new_switch_inputs = {
  379. std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
  380. origin_switch_inputs[kCNodeSwitchCond]};
  381. for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
  382. // 3.1 branch kernel graph and args
  383. KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
  384. // 3.2 recurse sub graph
  385. CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
  386. new_switch_inputs.push_back(branch_label);
  387. }
  388. std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
  389. cur_node->set_inputs(new_switch_inputs);
  390. cur_node->set_abstract(nullptr);
  391. MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
  392. }
  393. void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
  394. const CNodePtr &next_node,
  395. const NotNull<std::set<KernelGraphPtr> *> memo) {
  396. MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
  397. if (cur_node->size() < kCNodeSwitchLayerLength) {
  398. MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
  399. }
  400. auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
  401. MS_EXCEPTION_IF_NULL(branch_tuple);
  402. if (!branch_tuple->isa<CNode>()) {
  403. MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode";
  404. }
  405. const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
  406. // 1 return label
  407. auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
  408. // 2 add depend relationship
  409. InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
  410. if (next_node != nullptr && next_node != kg->get_return()) {
  411. InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
  412. }
  413. // 3 recurse sub graph
  414. const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
  415. if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
  416. MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << ".";
  417. }
  418. std::vector<AnfNodePtr> new_switch_inputs = {
  419. std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
  420. origin_switch_inputs[kCNodeSwitchCond]};
  421. for (size_t i = 0; i < branch_partial.size(); ++i) {
  422. // 3.1 branch kernel graph and args
  423. KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
  424. // 3.2 recurse sub graph
  425. CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
  426. new_switch_inputs.push_back(branch_label);
  427. }
  428. new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
  429. cur_node->set_inputs(new_switch_inputs);
  430. cur_node->set_abstract(nullptr);
  431. MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
  432. }
  433. KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
  434. if (!node.get()->isa<CNode>()) {
  435. if (IsValueNode<KernelGraph>(node)) {
  436. return GetValueNode<KernelGraphPtr>(node);
  437. }
  438. MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
  439. }
  440. // 2.1 branch kernel graph and args
  441. auto partial_cnode = utils::cast<CNodePtr>(node.get());
  442. MS_EXCEPTION_IF_NULL(partial_cnode);
  443. if (partial_cnode->size() < kCNodePartialLength) {
  444. MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
  445. }
  446. const auto &partial_inputs = partial_cnode->inputs();
  447. if (kCNodePartialFunc >= partial_inputs.size()) {
  448. MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
  449. }
  450. auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
  451. return branch_kg;
  452. }
  453. void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
  454. NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
  455. NotNull<AnfNodePtr> to) {
  456. std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
  457. std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
  458. MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]";
  459. if (from_outputs.size() != to_outputs.size()) {
  460. MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size["
  461. << to_outputs.size() << "]";
  462. }
  463. for (size_t i = 0; i < from_outputs.size(); i++) {
  464. auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
  465. if (assign_node != nullptr) {
  466. auto jump_node = GetJumpNode(from_graph, to_graph);
  467. const auto &from_graph_exe_order = from_graph->execution_order();
  468. auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
  469. if (jump_node_iter == from_graph_exe_order.end()) {
  470. MS_EXCEPTION_IF_NULL(jump_node);
  471. MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id();
  472. }
  473. // insert assign between jump_node -1 and jump_node
  474. if (jump_node_iter != from_graph_exe_order.begin()) {
  475. InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
  476. }
  477. if (jump_node != nullptr) {
  478. InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
  479. }
  480. }
  481. }
  482. }
  483. AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
  484. NotNull<AnfNodePtr> to) {
  485. if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
  486. AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
  487. return nullptr;
  488. }
  489. if (from.get() == to.get()) {
  490. return nullptr;
  491. }
  492. MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
  493. << to->DebugString();
  494. // config inputs of assign node
  495. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAssign->name())), to, from};
  496. // generate a new cnode
  497. auto assign_node = kg->NewCNode(inputs);
  498. MS_EXCEPTION_IF_NULL(assign_node);
  499. assign_node->set_abstract(to->abstract());
  500. return assign_node;
  501. }
  502. std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
  503. const NotNull<std::set<KernelGraphPtr> *> memo) {
  504. MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start";
  505. if (memo->find(graph) != memo->end()) {
  506. return {};
  507. }
  508. memo->insert(graph.get());
  509. graph->SetExecOrderByDefault();
  510. std::vector<CNodePtr> cnodes = graph->execution_order();
  511. auto end_label_goto = graph->get_end_goto();
  512. if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) {
  513. cnodes.pop_back();
  514. }
  515. AnfAlgo::ReorderExecList(NOT_NULL(&cnodes));
  516. if (end_label_goto != nullptr) {
  517. cnodes.push_back(end_label_goto);
  518. }
  519. std::vector<CNodePtr> execution_order;
  520. uint32_t child_order_index = 0;
  521. for (auto &node : cnodes) {
  522. execution_order.push_back(node);
  523. if (node == graph->get_end_goto()) {
  524. continue;
  525. }
  526. if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
  527. std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
  528. for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
  529. if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
  530. MS_LOG(EXCEPTION) << "Check label index fail";
  531. }
  532. if (child_order_index >= graph->child_graph_order().size()) {
  533. MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
  534. }
  535. auto child_graph = graph->child_graph_order()[child_order_index++];
  536. auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
  537. execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
  538. }
  539. } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
  540. uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
  541. if (!CheckLabelIndex(child_order_index, label_index, node, graph)) {
  542. MS_LOG(EXCEPTION) << "Check label index fail";
  543. }
  544. auto child_graph = graph->child_graph_order()[child_order_index++];
  545. auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
  546. execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
  547. }
  548. }
  549. graph->set_execution_order(execution_order);
  550. graph->PrintGraphExecuteOrder();
  551. return execution_order;
  552. }
  553. bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
  554. NotNull<KernelGraphPtr> graph) {
  555. const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
  556. // check index and child order size
  557. if (child_graph_order.size() <= IntToSize(order_index)) {
  558. MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
  559. << child_graph_order.size() << " goto index " << order_index;
  560. }
  561. auto child_graph = child_graph_order[order_index];
  562. MS_EXCEPTION_IF_NULL(child_graph);
  563. // get start_label_set_index of child graph
  564. auto start_label_set = child_graph->get_start_label();
  565. uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex);
  566. if (label_index != start_label_set_index) {
  567. MS_EXCEPTION_IF_NULL(cur_label);
  568. MS_EXCEPTION_IF_NULL(start_label_set);
  569. MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
  570. << " index " << start_label_set_index << " current child graph order : " << order_index;
  571. return false;
  572. } else {
  573. return true;
  574. }
  575. }
  576. void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
  577. MS_LOG(INFO) << "Graph id:" << kg->graph_id();
  578. kg->SetExecOrderByDefault();
  579. auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
  580. std::vector<KernelGraphPtr> child_graph_order;
  581. for (auto &call_node : call_nodes) {
  582. MS_EXCEPTION_IF_NULL(call_node);
  583. auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
  584. for (const auto &child_graph : call_child_graphs) {
  585. MS_EXCEPTION_IF_NULL(child_graph);
  586. if (child_graph != kg->parent_graph()) {
  587. child_graph->set_parent_graph(kg.get());
  588. }
  589. child_graph_order.push_back(child_graph);
  590. }
  591. }
  592. for (size_t i = 0; i < child_graph_order.size(); i++) {
  593. MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
  594. }
  595. kg->set_child_graph_order(child_graph_order);
  596. }
  597. } // namespace session
  598. } // namespace mindspore