/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/irpass/branch_culling.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/func_graph.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// Build tuple_getitem(geswitch(data, cond), switch_idx) in `graph`.
// For the switch op the output is a tuple: element 0 is the false branch, element 1 is the true branch.
AnfNodePtr GenerateSwitchNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data,
                              int switch_idx) {
  auto switch_node = prim::GetPythonOps("geswitch", "mindspore.ops.functional")->cast<PrimitivePtr>();
  std::vector<AnfNodePtr> switch_nodes{NewValueNode(switch_node), data, cond};
  auto switch_apply = graph->NewCNode(switch_nodes);
  std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), switch_apply,
                                              NewValueNode(MakeValue(switch_idx))};
  return graph->NewCNode(tuple_getitem_nodes);
}

// Guard `data` with geswitch on `cond` and take the true-branch output (index 1).
AnfNodePtr GenerateSwitchTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
  return GenerateSwitchNode(graph, cond, data, 1);
}

// Guard `data` with geswitch on `cond` and take the false-branch output (index 0).
AnfNodePtr GenerateSwitchFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
  return GenerateSwitchNode(graph, cond, data, 0);
}

// Return true when input `index` of `node` must NOT be wrapped in a geswitch node.
bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
  // The CNode inputs of the following Primitive with index in std::vector should not be guarded by geswitch
  // node because it is attribute or ge specific reason.
  // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
  // converted to switch guarded.
  // The table is constant, so it is built once on first use instead of on every call
  // (this function is called for every input of every node of a branch graph).
  static const std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
    {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
     {prim::kPrimStateSetItem, {1}},     {prim::kPrimTupleGetItem, {2}},
     {prim::kPrimEnvGetItem, {1}},       {prim::kPrimEnvSetItem, {1}},
     {prim::kPrimReduceSum, {2}},        {prim::kPrimReduceMean, {2}},
     {prim::kPrimReduceAll, {2}},        {prim::kPrimCast, {2}},
     {prim::kPrimTranspose, {2}},        {prim::kPrimOneHot, {2}},
     {prim::kPrimGatherV2, {3}},         {prim::kPrimReshape, {2}},
     {prim::kPrimAssign, {1}},           {prim::kPrimAssignAdd, {1}},
     {prim::kPrimAssignSub, {1}},        {prim::kPrimTensorSummary, {1}},
     {prim::kPrimImageSummary, {1}},     {prim::kPrimScalarSummary, {1}},
     {prim::kPrimApplyRMSProp, {6, 7, 8}}, {prim::kPrimCumSum, {2}},
     {prim::kPrimTile, {2}},             {prim::kPrimExpandDims, {2}},
     {prim::kPrimHistogramSummary, {1}}});
  for (const auto &item : white_list) {
    // Test the primitive once per entry, then look the index up in that entry's list.
    if (IsPrimitiveCNode(node, item.first) &&
        std::find(item.second.begin(), item.second.end(), index) != item.second.end()) {
      return true;
    }
  }
  // No input of these ops is guarded, whatever the index.
  static const std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend};
  return std::any_of(adapter_convert_ops.begin(), adapter_convert_ops.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

using NodeInputReplMap = std::unordered_map<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>;

// replace the nodes which should be changed
// For every (old, new) pair in nodes_changed, copy old's inputs into the (empty) new node. Whole-node
// replacements from repl_node take precedence over per-input replacements from repl_node_inputs.
// Then apply repl_node to the graph: a replacement Return node sets func_graph's output to its input(1).
// Every other node is swapped in through manager->Replace; a failed Replace raises.
void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
                          std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs,
                          const FuncGraphPtr &func_graph) {
  for (auto &node_pair : nodes_changed) {
    CNodePtr old_node = node_pair.first;
    CNodePtr new_node = node_pair.second;
    MS_EXCEPTION_IF_NULL(old_node);
    MS_EXCEPTION_IF_NULL(new_node);
    for (size_t i = 0; i < old_node->size(); i++) {
      auto input = old_node->input(i);
      // find() searches each map once (count() followed by operator[] searched twice).
      auto node_iter = repl_node.find(input);
      if (node_iter != repl_node.end()) {
        new_node->add_input(node_iter->second);
        continue;
      }
      auto input_iter = repl_node_inputs.find(std::pair<AnfNodePtr, size_t>(old_node, i));
      new_node->add_input(input_iter != repl_node_inputs.end() ? input_iter->second : input);
    }
  }

  for (auto &item : repl_node) {
    if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) {
      func_graph->set_output(item.second->cast<CNodePtr>()->input(1));
    } else if (!manager->Replace(item.first, item.second)) {
      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2)
                        << " to new: " << item.second->DebugString(2);
    }
  }
}

// trace the node that should add switch and replace them with new nodes in the graph
// Every CNode input other than inputs[0] is replaced by generate_func(graph, cond, input) when it meets two
// conditions: it is a ValueNode or belongs to another graph, and InConvertWhiteList does not exempt it.
// The affected CNodes are rebuilt and swapped in by RunSwitchNodeReplace.
// Returns `graph`, which is modified in place.
FuncGraphPtr TransformGraphCondBranchNodes(
  const FuncGraphPtr &graph, const AnfNodePtr &cond,
  const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // record the node that has been changed
  std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed;
  // record the node to be replaced
  std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node;
  // record the node input to be replaced
  NodeInputReplMap repl_node_inputs;
  const AnfNodeSet &nodes = graph->nodes();
  for (auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    // Bind by reference: the CNode is kept alive by `cnode` and its inputs are not modified in this loop.
    const auto &inputs = cnode->inputs();
    bool should_replace = false;
    for (size_t index = 0; index < inputs.size(); index++) {
      const auto &input_node = inputs[index];
      MS_EXCEPTION_IF_NULL(input_node);
      // Bypass the Primitive node which is inputs[0], and inputs that must not be guarded.
      if (index == 0 || InConvertWhiteList(node, index)) {
        continue;
      }
      // Only inputs that do not belong to this graph, or that are ValueNodes, get a switch node.
      bool from_other_graph = input_node->func_graph() != nullptr && input_node->func_graph() != graph;
      if (!from_other_graph && !input_node->isa<ValueNode>()) {
        continue;
      }
      auto switch_node = generate_func(graph, cond, input_node);
      // Check before recording, so a null node is never stored as a replacement input.
      if (switch_node == nullptr) {
        MS_LOG(EXCEPTION) << "generate switch node failed";
      }
      repl_node_inputs[std::pair<AnfNodePtr, size_t>(node, index)] = switch_node;
      should_replace = true;
    }
    if (should_replace) {
      auto new_node = graph->NewCNode();
      repl_node[node] = new_node;
      nodes_changed.emplace_back(cnode, new_node);
    }
  }
  // The containers are not used again here; move them into the by-value parameters instead of copying them.
  RunSwitchNodeReplace(manager, std::move(nodes_changed), std::move(repl_node), std::move(repl_node_inputs), graph);
  return graph;
}

// Cache of the helper ops built by GenerateSwitchDependNode.
// const_data is shared by both branches; square_ops / merge_ops are indexed by switch_idx
// (0: false branch, 1: true branch).
// This is unsynchronized global state; TransformGraphDependNode resets it around each pass.
struct SharedOp {
  tensor::TensorPtr const_data;
  CNodePtr square_ops[2];
  CNodePtr merge_ops[2];
} MergeNetOutput;

inline tensor::TensorPtr GetConstData() { return MergeNetOutput.const_data; }
inline void SetConstData(const tensor::TensorPtr &const_value) { MergeNetOutput.const_data = const_value; }

inline CNodePtr GetSquareOp(int switch_idx) { return MergeNetOutput.square_ops[switch_idx]; }
inline void SetSquareOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.square_ops[switch_idx] = op; }

inline CNodePtr GetMergeOp(int switch_idx) { return MergeNetOutput.merge_ops[switch_idx]; }
inline void SetMergeOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.merge_ops[switch_idx] = op; }

// Drop every cached op and the cached const tensor.
inline void ResetSharedOp() {
  SetConstData(nullptr);
  SetSquareOp(0, nullptr);
  SetSquareOp(1, nullptr);
  SetMergeOp(0, nullptr);
  SetMergeOp(1, nullptr);
}

// Create a 1-element int32 tensor holding 0.
tensor::TensorPtr ConstData() {
  std::vector<int> shp = {1};
  tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
  auto *val = static_cast<int32_t *>(const_data->data_c());
  *val = 0;
  return const_data;
}

// Build square(tuple_getitem(geswitch(const_data, cond), switch_idx)) in `graph`.
CNodePtr SquareOp(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx,
                  const tensor::TensorPtr &const_data) {
  auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
  // for the depended node , add two const data to merge the flow ,one for depended node with same switch,
  // the other use the opposite
  auto ctrl_data = NewValueNode(const_data);
  auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);
  std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
  auto square_op = graph->NewCNode(square_nodes);
  return square_op;
}

// Build merge(make_tuple(square_op, tuple_getitem(geswitch(const_data, cond), 1 - switch_idx))) in `graph`.
CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx,
                   const tensor::TensorPtr &const_data, const CNodePtr &square_op) {
  // for the depended node , add two const data to merge the flow ,one for depended node with same switch,
  // the other use the opposite
  auto oppsite_ctrl_data = NewValueNode(const_data);
  auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx);

  std::vector<AnfNodePtr> merge_nodes;
  auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
  merge_nodes.push_back(NewValueNode(PrimMerge));
  std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node};
  merge_nodes.push_back(graph->NewCNode(make_tuple_nodes));
  auto merge_op = graph->NewCNode(merge_nodes);
  return merge_op;
}

// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
// control_depend(output_node, square_op)
// Returns depend(merge_op, control_depend(output_node, square_op)).
// The const tensor, square op and merge op come from MergeNetOutput for this switch_idx and are created on
// first use, so all calls between two ResetSharedOp() calls share them.
AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node,
                                    int switch_idx) {
  tensor::TensorPtr const_data = GetConstData();
  if (const_data == nullptr) {
    const_data = ConstData();
    SetConstData(const_data);
  }

  CNodePtr square_op = GetSquareOp(switch_idx);
  if (square_op == nullptr) {
    square_op = SquareOp(graph, cond, switch_idx, const_data);
    SetSquareOp(switch_idx, square_op);
  }

  CNodePtr merge_op = GetMergeOp(switch_idx);
  if (merge_op == nullptr) {
    merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op);
    SetMergeOp(switch_idx, merge_op);
  }

  std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op};
  auto control_depend_op = graph->NewCNode(control_depend_nodes);

  std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op};
  auto depend_op = graph->NewCNode(depend_nodes);
  return depend_op;
}

// construct a merge output and add dependency with the netoutput node from control_depend
// we need to reserve the control_depend node, besides the generated merge node and control_depend node
// Returns make_tuple(ctrl_dep_node, merge(...), control_depend(ctrl_depend_dst, square_op)).
// Unlike GenerateSwitchDependNode, this always builds its own const tensor, square op and merge op
// (the MergeNetOutput cache is not used). It reuses ConstData/SquareOp/MergeNode instead of duplicating them.
CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
                                         const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst,
                                         int switch_idx) {
  // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same
  // switch the other use the opposite
  tensor::TensorPtr const_data = ConstData();
  CNodePtr square_op = SquareOp(graph, cond, switch_idx, const_data);
  CNodePtr merge_output = MergeNode(graph, cond, switch_idx, const_data, square_op);

  std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op};
  auto cond_dep_output = graph->NewCNode(control_depend_nodes);

  std::vector<AnfNodePtr> depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output,
                                                    cond_dep_output};
  return graph->NewCNode(depended_make_tuple_nodes);
}
switch nodes for true graph node inputs AnfNodePtr GenerateSwitchDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch return GenerateSwitchDependNode(graph, cond, data, 1); } // generate switch nodes for false graph node inputs AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch return GenerateSwitchDependNode(graph, cond, data, 0); } // generate switch nodes for true graph node inputs CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &con_input, const AnfNodePtr &output) { // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); } // generate switch nodes for false graph node inputs CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &con_input, const AnfNodePtr &output) { // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); } // to judge if the node used in ControlDepend is a net output node bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { auto uses = manager->node_users()[node]; bool is_output_node = true; for (auto &item : uses) { if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { continue; } is_output_node = false; break; } return is_output_node; } // generate node for Depended MakeTuple void GenerateReplNodeForDependMakeTuple( const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, const std::shared_ptr> &repl_node, const std::function 
&generate_func, const std::function &gen_ctl_depd_func) { MS_EXCEPTION_IF_NULL(graph->manager()); auto make_tuple_inputs = depended_node->cast()->inputs(); const size_t make_tuple_begin_idx = 1; std::vector new_make_tuple_nodes; bool replace_make_tuple = false; new_make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); for (size_t idx = make_tuple_begin_idx; idx < make_tuple_inputs.size(); idx++) { auto depended_tuple_input_node = make_tuple_inputs[idx]; if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimDepend)) { new_make_tuple_nodes.push_back(depended_tuple_input_node); continue; } if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimControlDepend)) { // only when the control depend input is not square op (the op to use as merge output) auto control_inputs = depended_tuple_input_node->cast()->inputs(); if (control_inputs.size() != 3) { MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); } // control inputs: primitive, src, dst auto dst_node = control_inputs[2]; if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); MS_EXCEPTION_IF_NULL(gen_node); auto tuple_inputs = gen_node->inputs(); // add depended tuple inputs to new_make_tuple directly for (size_t i = 1; i < tuple_inputs.size(); i++) { new_make_tuple_nodes.push_back(tuple_inputs[i]); } } replace_make_tuple = true; continue; } if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { auto gen_node = generate_func(graph, cond, depended_tuple_input_node); new_make_tuple_nodes.push_back(gen_node); replace_make_tuple = true; continue; } MS_LOG(WARNING) << "depended node being used by others, "; } if (replace_make_tuple) { auto make_tuple_op = graph->NewCNode(new_make_tuple_nodes); (*repl_node)[depended_node] = make_tuple_op; } } // generate a replace depend node for a single network output node 
void GenerateRepDepend( const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, const std::shared_ptr> &repl_node, const std::function &generate_func, const std::function &gen_ctl_depd_func) { auto inputs = node->inputs(); if (inputs.size() != 3) { MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; } std::vector new_depened_inputs; // Inputs should be [depend, actual_value, depended_node] auto depended_node = inputs[2]; new_depened_inputs.push_back(inputs[0]); new_depened_inputs.push_back(inputs[1]); // depended node should be make_tuple or a single depended node if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func); } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) { // only when the control depend input is not square op (the op to use as merge output) auto control_inputs = depended_node->cast()->inputs(); // control inputs: primitive, src, dst if (control_inputs.size() != 3) { MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); } auto dst_node = control_inputs[2]; if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node); (*repl_node)[depended_node] = gen_node; } } else { // Check if there is only single user for depend_node. if (graph->manager()->node_users()[depended_node].size() == 1) { auto gen_node = generate_func(graph, cond, depended_node); (*repl_node)[depended_node] = gen_node; } else { MS_LOG(WARNING) << "depended node being used by others"; } } } // generate depend node for netoutput node, to resolve the stream synchronize problem of ge // traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) // and add control_depend of graph output node and square node. 
FuncGraphPtr TransformGraphDependNode( const FuncGraphPtr &graph, const AnfNodePtr &cond, const std::function &gen_depend_func, const std::function &gen_ctl_depd_func) { auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); ResetSharedOp(); std::shared_ptr> repl_node = std::make_shared>(); // record the node to be replaced const AnfNodeSet &nodes = graph->nodes(); for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; } if (IsPrimitiveCNode(node, prim::kPrimDepend)) { auto cnode = node->cast(); if (cnode->size() != 3) { MS_LOG(EXCEPTION) << "Dependnode input size != 3"; } auto depended_node = cnode->input(2); MS_EXCEPTION_IF_NULL(depended_node); if (!depended_node->isa()) { continue; } if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { continue; } GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); } } ResetSharedOp(); for (auto &item : *repl_node) { if (!manager->Replace(item.first, item.second)) { MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed"; } } return graph; } FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); } FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); } // judge if the true and false graph output is compatible(they shall have same tuple size) bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const AbstractBasePtr &false_branch_abs) { MS_EXCEPTION_IF_NULL(true_branch_abs); MS_EXCEPTION_IF_NULL(false_branch_abs); if (true_branch_abs->isa() && false_branch_abs->isa()) 
{ abstract::AbstractTuplePtr true_branch_tuple = true_branch_abs->cast(); abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast(); if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) { MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size() << ", not equal to false banch size:" << false_branch_tuple->elements().size() << " "; return false; } bool all_compatible = true; for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { all_compatible = all_compatible && GraphOutputCompatible(true_branch_tuple->elements()[i], false_branch_tuple->elements()[i]); } return all_compatible; } TypePtr true_branch_type = true_branch_abs->BuildType(); TypePtr false_branch_type = false_branch_abs->BuildType(); MS_LOG(DEBUG) << "branch output Type equal?" << (*true_branch_type == *false_branch_type) << " true:" << true_branch_type->ToString() << " false:" << false_branch_type->ToString(); return (*true_branch_type == *false_branch_type); } AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, const AbstractBasePtr &true_graph_output_abs, const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph, const AnfNodePtr &cond) { MS_EXCEPTION_IF_NULL(true_graph_output_abs); MS_EXCEPTION_IF_NULL(false_graph_output_abs); MS_EXCEPTION_IF_NULL(cond); MS_EXCEPTION_IF_NULL(switch_graph); auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); MS_EXCEPTION_IF_NULL(PrimMerge); if (!true_graph_output_abs->isa()) { std::vector merge_nodes; merge_nodes.push_back(NewValueNode(PrimMerge)); std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node}; merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes)); std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))}; return 
switch_graph->NewCNode(tuple_getitem_nodes); } else { abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast(); abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast(); std::vector make_tuple_nodes; make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { std::vector true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node, NewValueNode(MakeValue(SizeToInt(i)))}; auto true_node = switch_graph->NewCNode(true_getitem_nodes); std::vector false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node, NewValueNode(MakeValue(SizeToInt(i)))}; auto false_node = switch_graph->NewCNode(false_getitem_nodes); auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i], false_branch_tuple->elements()[i], switch_graph, cond); make_tuple_nodes.push_back(merge_node); } return switch_graph->NewCNode(make_tuple_nodes); } } AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, const AbstractBasePtr &true_graph_output_abs, const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, const FuncGraphPtr &switch_graph) { if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) { MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString() << ", false:" << false_graph_output_abs->ToString(); } return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, switch_graph, cond); } } // namespace internal } // namespace irpass } // namespace opt } // namespace mindspore