/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/graph_transform.h"
#include <vector>
#include <algorithm>
#include <string>
#include "ir/graph_utils.h"

namespace mindspore {
/* namespace to support opt */
namespace opt {
// Returns true if any argument of `cnode` (inputs 1..N; input 0 is the callee)
// has a tuple abstract.
// - FuncGraph value-node arguments are skipped.
// - A Primitive value-node argument, or an argument without an abstract, makes
//   the function log a warning and return false. This is the conservative
//   "do not transform" answer.
// Inputs are scanned left to right, so the first decisive input wins.
bool CNodeHasTupleInput(const CNodePtr &cnode) {
  auto &inputs = cnode->inputs();
  for (size_t i = 1; i < inputs.size(); i++) {
    if (IsValueNode<FuncGraph>(inputs[i])) {
      continue;
    }
    if (IsValueNode<Primitive>(inputs[i])) {
      // unexpected high order primitive as cnode input when transform graph
      MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitive as input" << cnode->DebugString();
      return false;
    }
    auto abs = inputs[i]->abstract();
    if (abs == nullptr) {
      MS_LOG(WARNING) << "CheckTupleInput, got abstract nullptr for node:" << cnode->DebugString();
      return false;
    }
    if (abs->isa<abstract::AbstractTuple>()) {
      return true;
    }
  }
  return false;
}

// Returns true if any parameter of `fg` has a tuple abstract.
// A parameter without an abstract is treated like in CNodeHasTupleInput:
// log a warning and return false, instead of dereferencing a null pointer.
bool FuncGraphHasTupleInput(const FuncGraphPtr &fg) {
  auto &params = fg->parameters();
  for (auto &param : params) {
    auto abs = param->abstract();
    if (abs == nullptr) {
      MS_LOG(WARNING) << "FuncGraphHasTupleInput, got abstract nullptr for parameter:" << param->DebugString();
      return false;
    }
    if (abs->isa<abstract::AbstractTuple>()) {
      return true;
    }
  }
  return false;
}

// Expands the tuple-valued `node` (whose abstract is `abs`) into its leaf
// elements. For every element it creates a TupleGetItem(node, i) CNode in `fg`
// and gives it the element's abstract. Nested tuple elements are expanded
// recursively, so the result holds only non-tuple leaves, in depth-first
// left-to-right order.
std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
                                               const abstract::AbstractTuplePtr &abs) {
  auto &elements = abs->elements();
  std::vector<AnfNodePtr> tuple_node_expanded;
  for (size_t i = 0; i < elements.size(); i++) {
    auto elem_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(SizeToLong(i))});
    elem_node->set_abstract(elements[i]);
    if (elements[i]->isa<abstract::AbstractTuple>()) {
      auto nodes = TransformTupleArgument(fg, elem_node, elements[i]->cast<abstract::AbstractTuplePtr>());
      tuple_node_expanded.insert(tuple_node_expanded.end(), nodes.begin(), nodes.end());
    } else {
      tuple_node_expanded.push_back(elem_node);
    }
  }
  return tuple_node_expanded;
}

namespace {
// Appends cnode's inputs [start, end) to `*inputs`. Each tuple-typed input is
// replaced by its flattened leaves (see TransformTupleArgument), which are
// built in cnode's own func graph.
// Raises via MS_LOG(EXCEPTION) when an input has no abstract. `caller` is the
// prefix of that message, which keeps each caller's original log text.
void AppendFlattenedArguments(const CNodePtr &cnode, size_t start, const char *caller,
                              std::vector<AnfNodePtr> *inputs) {
  auto &cinputs = cnode->inputs();
  auto fg = cnode->func_graph();
  for (size_t i = start; i < cinputs.size(); i++) {
    auto abs = cinputs[i]->abstract();
    if (abs == nullptr) {
      MS_LOG(EXCEPTION) << caller << ":Node abstract should not be nullptr" << cinputs[i]->DebugString();
    }
    if (abs->isa<abstract::AbstractTuple>()) {
      auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
      inputs->insert(inputs->end(), nodes.begin(), nodes.end());
    } else {
      inputs->push_back(cinputs[i]);
    }
  }
}
}  // namespace

// Builds, in cnode's func graph, a new call of `trans_fg`. Its arguments are
// cnode's arguments (inputs 1..N) with tuple arguments flattened. The new node
// takes cnode's abstract. cnode itself is not modified or replaced here.
AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(trans_fg));
  AppendFlattenedArguments(cnode, 1, "TransformCallGraph", &inputs);
  auto new_node = cnode->func_graph()->NewCNode(inputs);
  new_node->set_abstract(cnode->abstract());
  return new_node;
}

// Builds, in cnode's func graph, a new Partial(trans_fg, args...) node.
// - cnode's inputs 0 and 1 are dropped. Presumably these are the Partial
//   primitive and the original graph; TODO confirm against callers.
// - The bound arguments are cnode's inputs 2..N, with tuple arguments flattened.
// The new node takes cnode's abstract.
AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimPartial));
  inputs.push_back(NewValueNode(trans_fg));
  AppendFlattenedArguments(cnode, 2, "TransformPartial", &inputs);
  auto new_node = cnode->func_graph()->NewCNode(inputs);
  new_node->set_abstract(cnode->abstract());
  return new_node;
}

// Builds, in cnode's func graph, a new CNode whose callee is `switch_node`.
// Its arguments are cnode's arguments (inputs 1..N) with tuple arguments
// flattened. The new node takes cnode's abstract.
AnfNodePtr TransformSwitchCall(const AnfNodePtr &switch_node, const CNodePtr &cnode) {
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(switch_node);
  AppendFlattenedArguments(cnode, 1, "TransformSwitchCall", &inputs);
  auto new_node = cnode->func_graph()->NewCNode(inputs);
  new_node->set_abstract(cnode->abstract());
  return new_node;
}
}  // namespace opt
}  // namespace mindspore