You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

branch_culling.cc 27 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/irpass/branch_culling.h"
  17. #include <memory>
  18. #include <utility>
  19. #include "utils/hash_map.h"
  20. #include "ir/func_graph.h"
  21. #include "frontend/operator/ops.h"
  22. namespace mindspore {
  23. namespace opt {
  24. namespace irpass {
  // Input positions inside a {prim::kPrimSwitch, cond, true_branch, false_branch} CNode.
  constexpr size_t kCondIndex{1};
  constexpr size_t kTrueBranchIndex{2};
  constexpr size_t kFalseBranchIndex{3};
  28. namespace internal {
  29. AnfNodePtr GenerateSwitchNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data,
  30. int64_t switch_idx) {
  31. auto switch_node = prim::GetPythonOps("geswitch", "mindspore.ops.functional")->cast<PrimitivePtr>();
  32. std::vector<AnfNodePtr> switch_nodes{NewValueNode(switch_node), data, cond};
  33. auto switch_apply = graph->NewCNode(switch_nodes);
  34. std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), switch_apply,
  35. NewValueNode(MakeValue(switch_idx))};
  36. return graph->NewCNode(tuple_getitem_nodes);
  37. }
  38. AnfNodePtr GenerateSwitchTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
  39. return GenerateSwitchNode(graph, cond, data, 1);
  40. }
  41. AnfNodePtr GenerateSwitchFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
  42. return GenerateSwitchNode(graph, cond, data, 0);
  43. }
  44. bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
  45. // The CNode inputs of the following Primitive with index in std::vector<size_t> should not be guarded by geswitch
  46. // node because it is attribute or ge specific reason.
  47. // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
  48. // converted to switch guarded.
  49. #ifndef ENABLE_SECURITY
  50. std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}},
  51. {prim::kPrimMomentum, {2, 3}},
  52. {prim::kPrimStateSetItem, {1}},
  53. {prim::kPrimTupleGetItem, {2}},
  54. {prim::kPrimEnvGetItem, {1}},
  55. {prim::kPrimEnvSetItem, {1}},
  56. {prim::kPrimReduceSum, {2}},
  57. {prim::kPrimReduceMean, {2}},
  58. {prim::kPrimReduceAll, {2}},
  59. {prim::kPrimCast, {2}},
  60. {prim::kPrimTranspose, {2}},
  61. {prim::kPrimOneHot, {2}},
  62. {prim::kPrimGather, {3}},
  63. {prim::kPrimReshape, {2}},
  64. {prim::kPrimAssign, {1}},
  65. {prim::kPrimAssignAdd, {1}},
  66. {prim::kPrimAssignSub, {1}},
  67. {prim::kPrimTensorSummary, {1}},
  68. {prim::kPrimImageSummary, {1}},
  69. {prim::kPrimScalarSummary, {1}},
  70. {prim::kPrimApplyRMSProp, {6, 7, 8}},
  71. {prim::kPrimCumSum, {2}},
  72. {prim::kPrimTile, {2}},
  73. {prim::kPrimExpandDims, {2}},
  74. {prim::kPrimHistogramSummary, {1}}});
  75. #else
  76. std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
  77. {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
  78. {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
  79. {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
  80. {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
  81. {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
  82. {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
  83. {prim::kPrimGather, {3}}, {prim::kPrimReshape, {2}},
  84. {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
  85. {prim::kPrimAssignSub, {1}}, {prim::kPrimApplyRMSProp, {6, 7, 8}},
  86. {prim::kPrimCumSum, {2}}, {prim::kPrimTile, {2}},
  87. {prim::kPrimExpandDims, {2}}});
  88. #endif
  89. for (auto &item : white_list) {
  90. auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
  91. return IsPrimitiveCNode(node, item.first) && idx == index;
  92. });
  93. if (matched) {
  94. return true;
  95. }
  96. }
  97. std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimLoad};
  98. for (auto &item : adapter_convert_ops) {
  99. if (IsPrimitiveCNode(node, item)) {
  100. return true;
  101. }
  102. }
  103. return false;
  104. }
  105. using NodeInputReplMap = mindspore::HashMap<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>;
  106. // replace the nodes which should be changed
  107. void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
  108. mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs,
  109. const FuncGraphPtr &func_graph) {
  110. for (auto &node_pair : nodes_changed) {
  111. CNodePtr old_node = node_pair.first;
  112. CNodePtr new_node = node_pair.second;
  113. MS_EXCEPTION_IF_NULL(old_node);
  114. MS_EXCEPTION_IF_NULL(new_node);
  115. for (size_t i = 0; i < old_node->size(); i++) {
  116. auto input = old_node->input(i);
  117. if (repl_node.count(input) != 0) {
  118. new_node->add_input(repl_node[input]);
  119. } else if (repl_node_inputs.count(std::pair<AnfNodePtr, size_t>(old_node, i)) != 0) {
  120. new_node->add_input(repl_node_inputs[std::pair<AnfNodePtr, size_t>(old_node, i)]);
  121. } else {
  122. new_node->add_input(input);
  123. }
  124. }
  125. }
  126. for (auto &item : repl_node) {
  127. if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) {
  128. func_graph->set_output(item.second->cast<CNodePtr>()->input(1));
  129. } else if (!manager->Replace(item.first, item.second)) {
  130. MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2)
  131. << " to new: " << item.second->DebugString(2);
  132. }
  133. }
  134. }
  135. // trace the node that should add switch and replace them with new nodes in the graph
  136. FuncGraphPtr TransformGraphCondBranchNodes(
  137. const FuncGraphPtr &graph, const AnfNodePtr &cond,
  138. const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
  139. auto manager = graph->manager();
  140. MS_EXCEPTION_IF_NULL(manager);
  141. // record the node that has been changed
  142. std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed;
  143. // record the node to be replaced
  144. mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl_node;
  145. // record the node input to be replaced
  146. NodeInputReplMap repl_node_inputs;
  147. const AnfNodeSet &nodes = graph->nodes();
  148. for (auto &node : nodes) {
  149. MS_EXCEPTION_IF_NULL(node);
  150. if (!node->isa<CNode>()) {
  151. continue;
  152. }
  153. auto inputs = node->cast<CNodePtr>()->inputs();
  154. bool should_replace = false;
  155. // if the apply input does not belong to graph, insert a switch node
  156. for (size_t index = 0; index < inputs.size(); index++) {
  157. auto input_node = inputs[index];
  158. MS_EXCEPTION_IF_NULL(input_node);
  159. if (HasAbstractMonad(input_node)) {
  160. // Do not guard with switch for monad inputs.
  161. continue;
  162. }
  163. // for some ops input should not guard it with switch
  164. if (InConvertWhiteList(node, index)) {
  165. continue;
  166. }
  167. // If the input for node is not the graph belonged, or it is an ValueNode.
  168. // Bypass the Primitive node which is inputs[0].
  169. if ((index >= 1 && input_node->func_graph() != nullptr && input_node->func_graph() != graph) ||
  170. ((index >= 1 && input_node->isa<ValueNode>()))) {
  171. input_node = generate_func(graph, cond, input_node);
  172. repl_node_inputs[std::pair<AnfNodePtr, size_t>(node, index)] = input_node;
  173. should_replace = true;
  174. }
  175. if (input_node == nullptr) {
  176. MS_LOG(EXCEPTION) << "generate switch node failed";
  177. }
  178. }
  179. if (should_replace) {
  180. auto new_node = graph->NewCNode();
  181. repl_node[node] = new_node;
  182. nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node);
  183. }
  184. }
  185. RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph);
  186. return graph;
  187. }
  188. struct SharedOp {
  189. tensor::TensorPtr const_data;
  190. CNodePtr square_ops[2];
  191. CNodePtr merge_ops[2];
  192. } MergeNetOutput;
  193. inline tensor::TensorPtr GetConstData() { return MergeNetOutput.const_data; }
  194. inline void SetConstData(const tensor::TensorPtr &const_value) { MergeNetOutput.const_data = const_value; }
  195. inline CNodePtr GetSquareOp(int64_t switch_idx) { return MergeNetOutput.square_ops[switch_idx]; }
  196. inline void SetSquareOp(int64_t switch_idx, const CNodePtr &op) { MergeNetOutput.square_ops[switch_idx] = op; }
  197. inline CNodePtr GetMergeOp(int64_t switch_idx) { return MergeNetOutput.merge_ops[switch_idx]; }
  198. inline void SetMergeOp(int64_t switch_idx, const CNodePtr &op) { MergeNetOutput.merge_ops[switch_idx] = op; }
  199. inline void ResetSharedOp() {
  200. SetConstData(nullptr);
  201. SetSquareOp(0, nullptr);
  202. SetSquareOp(1, nullptr);
  203. SetMergeOp(0, nullptr);
  204. SetMergeOp(1, nullptr);
  205. }
  206. tensor::TensorPtr ConstData() {
  207. std::vector<int64_t> shp = {1};
  208. tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
  209. auto *val = static_cast<int64_t *>(const_data->data_c());
  210. *val = 0;
  211. return const_data;
  212. }
  213. CNodePtr SquareOp(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t switch_idx,
  214. const tensor::TensorPtr &const_data) {
  215. auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
  216. // for the depended node , add two const data to merge the flow ,one for depended node with same switch,
  217. // the other use the opposite
  218. auto ctrl_data = NewValueNode(const_data);
  219. auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);
  220. std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
  221. auto square_op = graph->NewCNode(square_nodes);
  222. return square_op;
  223. }
  224. CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t switch_idx,
  225. const tensor::TensorPtr &const_data, const CNodePtr &square_op) {
  226. // for the depended node , add two const data to merge the flow ,one for depended node with same switch,
  227. // the other use the opposite
  228. auto oppsite_ctrl_data = NewValueNode(const_data);
  229. auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx);
  230. std::vector<AnfNodePtr> merge_nodes;
  231. auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
  232. merge_nodes.push_back(NewValueNode(PrimMerge));
  233. std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node};
  234. merge_nodes.push_back(graph->NewCNode(make_tuple_nodes));
  235. auto merge_op = graph->NewCNode(merge_nodes);
  236. return merge_op;
  237. }
  238. // merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
  239. AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node,
  240. int64_t switch_idx) {
  241. tensor::TensorPtr const_data = GetConstData();
  242. if (const_data == nullptr) {
  243. const_data = ConstData();
  244. SetConstData(const_data);
  245. }
  246. CNodePtr square_op = GetSquareOp(switch_idx);
  247. if (square_op == nullptr) {
  248. square_op = SquareOp(graph, cond, switch_idx, const_data);
  249. SetSquareOp(switch_idx, square_op);
  250. }
  251. auto manager = graph->manager();
  252. MS_EXCEPTION_IF_NULL(manager);
  253. AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), square_op, output_node};
  254. auto depend_cnode = graph->NewCNode(inputs);
  255. if (!manager->Replace(square_op, depend_cnode)) {
  256. MS_LOG(EXCEPTION) << square_op->DebugString() << ", replace node failed.";
  257. }
  258. CNodePtr merge_op = GetMergeOp(switch_idx);
  259. if (merge_op == nullptr) {
  260. merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op);
  261. SetMergeOp(switch_idx, merge_op);
  262. }
  263. return merge_op;
  264. }
  265. // generate switch nodes for true graph node inputs
  266. AnfNodePtr GenerateSwitchDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
  267. // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
  268. return GenerateSwitchDependNode(graph, cond, data, 1);
  269. }
  270. // generate switch nodes for false graph node inputs
  271. AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
  272. // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
  273. return GenerateSwitchDependNode(graph, cond, data, 0);
  274. }
  275. // to judge if the node used in Depend is a net output node
  276. bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
  277. auto uses = manager->node_users()[node];
  278. bool is_output_node = true;
  279. for (auto &item : uses) {
  280. if (IsPrimitiveCNode(item.first, prim::kPrimDepend)) {
  281. continue;
  282. }
  283. is_output_node = false;
  284. break;
  285. }
  286. return is_output_node;
  287. }
  288. // generate node for Depended MakeTuple
  289. void GenerateReplNodeForDependMakeTuple(
  290. const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
  291. const std::shared_ptr<mindspore::HashMap<AnfNodePtr, AnfNodePtr>> &repl_node,
  292. const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
  293. MS_EXCEPTION_IF_NULL(graph->manager());
  294. auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs();
  295. const size_t make_tuple_begin_idx = 1;
  296. std::vector<AnfNodePtr> new_make_tuple_nodes;
  297. bool replace_make_tuple = false;
  298. new_make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
  299. for (size_t idx = make_tuple_begin_idx; idx < make_tuple_inputs.size(); idx++) {
  300. auto depended_tuple_input_node = make_tuple_inputs[idx];
  301. if (IsPrimitiveCNode(depended_tuple_input_node->cast<CNodePtr>(), prim::kPrimDepend)) {
  302. new_make_tuple_nodes.push_back(depended_tuple_input_node);
  303. continue;
  304. }
  305. if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) {
  306. auto gen_node = generate_func(graph, cond, depended_tuple_input_node);
  307. new_make_tuple_nodes.push_back(gen_node);
  308. replace_make_tuple = true;
  309. continue;
  310. }
  311. MS_LOG(WARNING) << "depended node being used by others, ";
  312. }
  313. if (replace_make_tuple) {
  314. auto make_tuple_op = graph->NewCNode(new_make_tuple_nodes);
  315. (*repl_node)[depended_node] = make_tuple_op;
  316. }
  317. }
  318. // generate a replace depend node for a single network output node
  319. void GenerateRepDepend(
  320. const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
  321. const std::shared_ptr<mindspore::HashMap<AnfNodePtr, AnfNodePtr>> &repl_node,
  322. const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
  323. MS_EXCEPTION_IF_NULL(graph->manager());
  324. auto inputs = node->inputs();
  325. if (inputs.size() != kDependInputSize) {
  326. MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node].";
  327. }
  328. std::vector<AnfNodePtr> new_depened_inputs;
  329. // Inputs should be [depend, actual_value, depended_node]
  330. auto depended_node = inputs[kDependAttachNodeIndex];
  331. new_depened_inputs.push_back(inputs[0]);
  332. new_depened_inputs.push_back(inputs[1]);
  333. // depended node should be make_tuple or a single depended node
  334. if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) {
  335. GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func);
  336. } else {
  337. // Check if there is only single user for depend_node.
  338. if (graph->manager()->node_users()[depended_node].size() == 1) {
  339. auto gen_node = generate_func(graph, cond, depended_node);
  340. (*repl_node)[depended_node] = gen_node;
  341. } else {
  342. MS_LOG(WARNING) << "depended node being used by others";
  343. }
  344. }
  345. }
  346. // generate depend node for netoutput node, to resolve the stream synchronize problem of ge
  347. // traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const)
  348. FuncGraphPtr TransformGraphDependNode(
  349. const FuncGraphPtr &graph, const AnfNodePtr &cond,
  350. const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func) {
  351. auto manager = graph->manager();
  352. MS_EXCEPTION_IF_NULL(manager);
  353. ResetSharedOp();
  354. std::shared_ptr<mindspore::HashMap<AnfNodePtr, AnfNodePtr>> repl_node =
  355. std::make_shared<mindspore::HashMap<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced
  356. const AnfNodeSet &nodes = graph->nodes();
  357. for (auto &node : nodes) {
  358. MS_EXCEPTION_IF_NULL(node);
  359. if (!node->isa<CNode>()) {
  360. continue;
  361. }
  362. if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
  363. auto cnode = node->cast<CNodePtr>();
  364. if (cnode->size() != kDependInputSize) {
  365. MS_LOG(EXCEPTION) << "Dependnode input size != " << kDependInputSize;
  366. }
  367. auto depended_node = cnode->input(kDependAttachNodeIndex);
  368. MS_EXCEPTION_IF_NULL(depended_node);
  369. if (!depended_node->isa<CNode>()) {
  370. continue;
  371. }
  372. if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) {
  373. continue;
  374. }
  375. GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func);
  376. }
  377. }
  378. ResetSharedOp();
  379. for (auto &item : *repl_node) {
  380. if (!manager->Replace(item.first, item.second)) {
  381. MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed";
  382. }
  383. }
  384. return graph;
  385. }
  386. FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
  387. (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode);
  388. return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode);
  389. }
  390. FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
  391. (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode);
  392. return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode);
  393. }
  394. // judge if the true and false graph output is compatible(they shall have same tuple size)
  395. bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const AbstractBasePtr &false_branch_abs) {
  396. MS_EXCEPTION_IF_NULL(true_branch_abs);
  397. MS_EXCEPTION_IF_NULL(false_branch_abs);
  398. if (true_branch_abs->isa<abstract::AbstractTuple>() && false_branch_abs->isa<abstract::AbstractTuple>()) {
  399. abstract::AbstractTuplePtr true_branch_tuple = true_branch_abs->cast<abstract::AbstractTuplePtr>();
  400. abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast<abstract::AbstractTuplePtr>();
  401. if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) {
  402. MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size()
  403. << ", not equal to false branch size:" << false_branch_tuple->elements().size() << " ";
  404. return false;
  405. }
  406. bool all_compatible = true;
  407. for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) {
  408. all_compatible =
  409. all_compatible && GraphOutputCompatible(true_branch_tuple->elements()[i], false_branch_tuple->elements()[i]);
  410. }
  411. return all_compatible;
  412. }
  413. TypePtr true_branch_type = true_branch_abs->BuildType();
  414. TypePtr false_branch_type = false_branch_abs->BuildType();
  415. MS_LOG(DEBUG) << "branch output Type equal?" << (*true_branch_type == *false_branch_type)
  416. << " true:" << true_branch_type->ToString() << " false:" << false_branch_type->ToString();
  417. return (*true_branch_type == *false_branch_type);
  418. }
  419. // block_nodes[0]: condition node
  420. // block_nodes[1]: true branch node
  421. // block_nodes[2]: false branch node
  422. // branch_output_abs[0]: true branch abstract
  423. // branch_output_abs[1]: false branch abstract
  424. AnfNodePtr GenerateMergeNodes(const std::vector<AnfNodePtr> &block_nodes,
  425. const std::vector<AbstractBasePtr> &branch_output_abs, const FuncGraphPtr &switch_graph) {
  426. MS_EXCEPTION_IF_NULL(branch_output_abs[0]);
  427. MS_EXCEPTION_IF_NULL(branch_output_abs[1]);
  428. MS_EXCEPTION_IF_NULL(block_nodes[0]);
  429. MS_EXCEPTION_IF_NULL(switch_graph);
  430. auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
  431. MS_EXCEPTION_IF_NULL(PrimMerge);
  432. if (!branch_output_abs[0]->isa<abstract::AbstractTuple>()) {
  433. std::vector<AnfNodePtr> merge_nodes;
  434. merge_nodes.push_back(NewValueNode(PrimMerge));
  435. std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), block_nodes[1], block_nodes[2]};
  436. merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes));
  437. std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem),
  438. switch_graph->NewCNode(merge_nodes),
  439. NewValueNode(MakeValue(static_cast<int64_t>(0)))};
  440. return switch_graph->NewCNode(tuple_getitem_nodes);
  441. } else {
  442. auto true_branch_tuple = branch_output_abs[0]->cast<abstract::AbstractTuplePtr>();
  443. auto false_branch_tuple = branch_output_abs[1]->cast<abstract::AbstractTuplePtr>();
  444. std::vector<AnfNodePtr> make_tuple_nodes;
  445. make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
  446. for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) {
  447. std::vector<AnfNodePtr> true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), block_nodes[1],
  448. NewValueNode(MakeValue(SizeToLong(i)))};
  449. auto true_node = switch_graph->NewCNode(true_getitem_nodes);
  450. std::vector<AnfNodePtr> false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), block_nodes[2],
  451. NewValueNode(MakeValue(SizeToLong(i)))};
  452. auto false_node = switch_graph->NewCNode(false_getitem_nodes);
  453. auto merge_node = GenerateMergeNodes(
  454. {
  455. block_nodes[0],
  456. true_node,
  457. false_node,
  458. },
  459. {true_branch_tuple->elements()[i], false_branch_tuple->elements()[i]}, switch_graph);
  460. make_tuple_nodes.push_back(merge_node);
  461. }
  462. return switch_graph->NewCNode(make_tuple_nodes);
  463. }
  464. }
  465. AnfNodePtr TransformMergeBranches(const std::vector<AnfNodePtr> &block_nodes,
  466. const std::vector<AbstractBasePtr> &branch_output_abs,
  467. const FuncGraphPtr &func_graph) {
  468. if (!GraphOutputCompatible(branch_output_abs[0], branch_output_abs[1])) {
  469. MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << branch_output_abs[0]->ToString()
  470. << ", false:" << branch_output_abs[1]->ToString();
  471. }
  472. return GenerateMergeNodes(block_nodes, branch_output_abs, func_graph);
  473. }
  474. } // namespace internal
  475. bool ConvertSwitchReplacement::CheckSwitchBranch(const AnfNodePtr &node) {
  476. if (!IsValueNode<FuncGraph>(node)) {
  477. return false;
  478. }
  479. // If graph contains FuncGraph, then ignore this node.
  480. auto graph = GetValueNode<FuncGraphPtr>(node);
  481. for (auto &item : graph->value_nodes()) {
  482. auto value_node = item.first;
  483. if (IsValueNode<FuncGraph>(value_node)) {
  484. return false;
  485. }
  486. }
  487. return true;
  488. }
  489. bool ConvertSwitchReplacement::CheckSwitchWrapNode(const AnfNodePtr &node) {
  490. // {{prim::kPrimSwitch, X, G1, G2}, Xs}.
  491. if (node->isa<CNode>()) {
  492. auto inp0 = node->cast<CNodePtr>()->input(0);
  493. if (IsPrimitiveCNode(inp0, prim::kPrimSwitch)) {
  494. auto switch_node = inp0->cast<CNodePtr>();
  495. // for switch replace method, only graphs without graph inside can be replaced
  496. if (CheckSwitchBranch(switch_node->input(kTrueBranchIndex)) &&
  497. CheckSwitchBranch(switch_node->input(kFalseBranchIndex))) {
  498. return true;
  499. }
  500. }
  501. }
  502. return false;
  503. }
  504. void ConvertSwitchReplacement::TransformSwitchBranchReplace(const AnfNodePtr &node) {
  505. auto cnode = node->cast<CNodePtr>();
  506. auto switch_cnode = cnode->input(0)->cast<CNodePtr>();
  507. auto cond = switch_cnode->input(kCondIndex);
  508. auto true_br = switch_cnode->input(kTrueBranchIndex);
  509. auto false_br = switch_cnode->input(kFalseBranchIndex);
  510. auto g1 = GetValueNode<FuncGraphPtr>(true_br);
  511. auto g2 = GetValueNode<FuncGraphPtr>(false_br);
  512. auto true_output = g1->output()->abstract();
  513. auto false_output = g2->output()->abstract();
  514. auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1, cond);
  515. auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2, cond);
  516. std::vector<AnfNodePtr> params;
  517. if (cnode && cnode->size() > 1) {
  518. // There are arguments for the call of switch result,
  519. // usually these are monad states added by auto-monad.
  520. for (size_t i = 1; i < cnode->size(); ++i) {
  521. params.push_back(cnode->inputs().at(i));
  522. }
  523. }
  524. auto fg = node->func_graph();
  525. auto cloned_g1 = InlineClone(trans_g1, fg, params);
  526. auto cloned_g2 = InlineClone(trans_g2, fg, params);
  527. auto new_node = internal::TransformMergeBranches({cond, cloned_g1, cloned_g2}, {true_output, false_output}, fg);
  528. (void)fg->manager()->Replace(node, new_node);
  529. }
  530. } // namespace irpass
  531. } // namespace opt
  532. } // namespace mindspore