/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/if_pass.h"
#include <memory>
#include <string>
#include <vector>
#include "mindspore/lite/include/errorcode.h"
#include "mindspore/lite/src/ops/primitive_c.h"
#include "tools/anf_importer/import_from_meta_graphT.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/ops/primitive_c.h"
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "src/common/log_adapter.h"
#include "src/ops/switch.h"
#include "src/ops/partial.h"

namespace mindspore::opt {
/// @brief Builds a ValueNode wrapping a newly allocated Switch primitive.
/// @return The ValueNode, or nullptr if allocating the PrimitiveT or its SwitchT payload failed.
ValueNodePtr IfPass::GetSwitchAnfPrim() {
  std::unique_ptr<schema::PrimitiveT> switch_primitiveT(new (std::nothrow) schema::PrimitiveT);
  if (switch_primitiveT == nullptr) {
    MS_LOG(ERROR) << "new switch_primitiveT failed";
    return nullptr;
  }
  switch_primitiveT->value.type = schema::PrimitiveType_Switch;
  switch_primitiveT->value.value = new (std::nothrow) schema::SwitchT;
  if (switch_primitiveT->value.value == nullptr) {
    // The unique_ptr releases the PrimitiveT on this early return.
    MS_LOG(ERROR) << "new SwitchT failed";
    return nullptr;
  }

  // Ownership of the PrimitiveT is handed to the PrimitiveC wrapper.
  auto switch_prim = std::make_shared<lite::Switch>(switch_primitiveT.release());
  return NewValueNode(switch_prim);
}

/// @brief For every CNode in @p node_list, rewires each input that is a Parameter named
///        @p para_name so that it points at @p new_input_cnode instead.
/// @param node_list Nodes to scan; non-CNode entries are skipped.
/// @param new_input_cnode Replacement input node.
/// @param para_name Name of the Parameter inputs to replace.
void IfPass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode,
                          std::string para_name) {
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = utils::cast<CNodePtr>(node);
    for (size_t k = 0; k < cnode->inputs().size(); k++) {
      if (!utils::isa<ParameterPtr>(cnode->input(k))) {
        continue;
      }
      auto para_input = utils::cast<ParameterPtr>(cnode->input(k));
      if (para_input->name() == para_name) {
        cnode->set_input(k, new_input_cnode);
      }
    }
  }
}

/// @brief Rewrites every If CNode in @p graph into a Switch CNode.
///
/// For each If node, two CNodes are built. Each has a branch FuncGraph ValueNode as input 0,
/// followed by the If's trailing data inputs (those after kIfMinInputSize). These are presumably
/// treated as partial applications by later converter stages; confirm against the exporter.
/// A Switch CNode {Switch prim, then_partial, else_partial, cond, data inputs...} is then created,
/// and every user of the If node is redirected to it.
///
/// @param graph Graph to transform in place; its manager must be set.
/// @return true on success. On failure, returns false and records RET_FAILED in the global return code.
bool IfPass::Run(const FuncGraphPtr &graph) {
  auto node_list = TopoSort(graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    if (opt::GetCNodeType(node) != schema::PrimitiveType_If) {
      continue;
    }
    auto if_cnode = node->cast<CNodePtr>();
    MS_ASSERT(if_cnode != nullptr);
    if (if_cnode->inputs().size() < kIfMinInputSize) {
      MS_LOG(ERROR) << "if input is not right.";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
      return false;
    }

    // The input order of an If node is fixed.
    auto then_vnode = if_cnode->input(kIfThenIndex);
    auto else_vnode = if_cnode->input(kIfElseIndex);
    auto cond_vnode = if_cnode->input(kIfCondIndex);

    auto then_fg = GetValueNode<FuncGraphPtr>(then_vnode);
    auto else_fg = GetValueNode<FuncGraphPtr>(else_vnode);
    if (then_fg == nullptr || else_fg == nullptr) {
      MS_LOG(ERROR) << "Get value as func_graph failed.";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
      return false;
    }
    // FuncGraph::output() is null when the graph has no return value set; its abstract is read below.
    if (then_fg->output() == nullptr || else_fg->output() == nullptr) {
      MS_LOG(ERROR) << "then/else func_graph has no output.";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
      return false;
    }

    auto manager = graph->manager();
    if (manager == nullptr) {
      MS_LOG(ERROR) << "manager of graph is nullptr.";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
      return false;
    }

    ValueNodePtr switch_anf_primitive = GetSwitchAnfPrim();
    if (switch_anf_primitive == nullptr) {
      MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
      return false;
    }

    // Data inputs of the If node: everything after the fixed prim/then/else/cond prefix.
    const auto &if_inputs = if_cnode->inputs();
    const std::vector<AnfNodePtr> branch_args(if_inputs.begin() + kIfMinInputSize, if_inputs.end());

    // Builds {branch func_graph vnode, data inputs...}, typed with the branch graph's output abstract.
    auto make_branch_node = [&](const AnfNodePtr &fg_vnode, const FuncGraphPtr &fg, const std::string &suffix) {
      std::vector<AnfNodePtr> partial_op_inputs{fg_vnode};
      partial_op_inputs.insert(partial_op_inputs.end(), branch_args.begin(), branch_args.end());
      auto partial_node = graph->NewCNode(partial_op_inputs);
      partial_node->set_fullname_with_scope(node->fullname_with_scope() + suffix);
      partial_node->set_abstract(fg->output()->abstract());
      return partial_node;
    };
    auto then_partial_node = make_branch_node(then_vnode, then_fg, "-partial-if-then");
    auto else_partial_node = make_branch_node(else_vnode, else_fg, "-partial-if-else");

    // Switch inputs: {prim, then_partial, else_partial, cond, data inputs...}.
    std::vector<AnfNodePtr> switch_op_inputs{switch_anf_primitive, then_partial_node, else_partial_node, cond_vnode};
    switch_op_inputs.insert(switch_op_inputs.end(), branch_args.begin(), branch_args.end());
    auto switch_cnode = graph->NewCNode(switch_op_inputs);
    switch_cnode->set_fullname_with_scope(node->fullname_with_scope() + "-Switch");
    switch_cnode->set_abstract(if_cnode->abstract());

    // Redirect every user of the If node to the Switch node. The user set is copied on purpose:
    // SetEdge updates manager->node_users() while we iterate.
    auto node_users = manager->node_users()[if_cnode];
    for (auto &node_user : node_users) {
      manager->SetEdge(node_user.first, node_user.second, switch_cnode);
    }
  }
  return true;
}
}  // namespace mindspore::opt