/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "frontend/optimizer/control_depend.h" #include #include #include #include #include "frontend/optimizer/optimizer.h" namespace mindspore { namespace opt { std::vector DoControlDepend(const FuncGraphPtr &graph, const CNodePtr &return_node, const std::vector &effect_index, const std::vector &cnodes) { std::vector depend_nodes{NewValueNode(prim::kPrimDepend), return_node->input(1)}; std::vector make_tuple{NewValueNode(prim::kPrimMakeTuple)}; size_t effect_size = effect_index.size(); for (size_t i = 0; i < effect_size; i++) { size_t pre_index = 0; if (i > 0) { pre_index = effect_index[i - 1] + 1; } size_t this_index = effect_index[i]; size_t last_index = cnodes.size() - 2; if (i < effect_size - 1) { last_index = effect_index[i + 1]; } if (this_index > pre_index) { std::vector pre_segment; for (size_t k = pre_index; k < this_index; k++) { // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { continue; } pre_segment.push_back(cnodes[k]); } auto roots = FindRoots(pre_segment); for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { AnfNodePtr control_depend = graph->NewCNode({NewValueNode(prim::kPrimControlDepend), *iter, cnodes[this_index]}); make_tuple.push_back(control_depend); } } if (last_index > this_index) { std::vector last_segment; for (size_t k = this_index + 1; k <= last_index; k++) { // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { continue; } last_segment.push_back(cnodes[k]); } auto leaves = FindLeaves(last_segment); for (auto iter = leaves->begin(); iter != leaves->end(); (void)iter++) { AnfNodePtr control_depend = graph->NewCNode({NewValueNode(prim::kPrimControlDepend), cnodes[this_index], *iter}); make_tuple.push_back(control_depend); } } } depend_nodes.push_back(graph->NewCNode(make_tuple)); return depend_nodes; } void AddControlDepend(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); std::list orders = graph->GetOrderedCnodes(); std::vector cnodes(orders.begin(), orders.end()); size_t cnodes_size = cnodes.size(); // get effect index of cnodes std::vector effect_index{}; for (size_t i = 0; i < cnodes_size; i++) { if (graph->HasEffect(cnodes[i])) { effect_index.push_back(i); } } if (effect_index.empty()) { return; } AnfNodePtr last_node = cnodes[cnodes_size - 1]; CNodePtr return_node; if (last_node->isa()) { return_node = last_node->cast(); } MS_EXCEPTION_IF_NULL(return_node); if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) { MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode."; } if (return_node->inputs().size() < 2) { MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2."; } auto depend_node_inputs = DoControlDepend(graph, return_node, effect_index, cnodes); auto depend_cnode = graph->NewCNode(depend_node_inputs); depend_cnode->set_abstract(depend_cnode->input(1)->abstract()); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); if (!manager->Replace(return_node->input(1), depend_cnode)) { MS_LOG(EXCEPTION) << "Depend replace node failed"; } } } // namespace opt } // namespace mindspore